From 03ee67bfd34b9e872b33eb05fef5db83410b16f3 Mon Sep 17 00:00:00 2001 From: WDevelopsWebApps <97454358+WDevelopsWebApps@users.noreply.github.com> Date: Wed, 28 Sep 2022 10:53:40 +0200 Subject: add advanced saving for save button --- modules/images.py | 5 ++++- modules/ui.py | 37 +++++++++++++++++++++++++++++-------- 2 files changed, 33 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 9458bf8d..923f81df 100644 --- a/modules/images.py +++ b/modules/images.py @@ -290,7 +290,10 @@ def apply_filename_pattern(x, p, seed, prompt): x = x.replace("[cfg]", str(p.cfg_scale)) x = x.replace("[width]", str(p.width)) x = x.replace("[height]", str(p.height)) - x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False)) + #currently disabled if using the save button, will work otherwise + # if enabled it will cause a bug because styles is not included in the save_files data dictionary + if hasattr(p, "styles"): + x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False)) x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) diff --git a/modules/ui.py b/modules/ui.py index 7db8edbd..87a86a45 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -28,6 +28,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste +from modules.images import apply_filename_pattern, get_next_sequence_number # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI mimetypes.init() @@ -90,13 +91,26 @@ def send_gradio_gallery_to_image(x): def save_files(js_data, images, index): - import csv - - os.makedirs(opts.outdir_save, exist_ok=True) - + import csv filenames = [] + #quick dictionary to class object conversion. Its neccesary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + data = json.loads(js_data) + p = MyObject(data) + path = opts.outdir_save + save_to_dirs = opts.save_to_dirs + + if save_to_dirs: + dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, p.seed, p.prompt) + path = os.path.join(opts.outdir_save, dirname) + + os.makedirs(path, exist_ok=True) if index > -1 and opts.save_selected_only and (index > 0 or not opts.return_grid): # ensures we are looking at a specific non-grid picture, and we have save_selected_only images = [images[index]] @@ -107,11 +121,18 @@ def save_files(js_data, images, index): writer = csv.writer(file) if at_start: writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - filename_base = str(int(time.time() * 1000)) + file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]" + if file_decoration != "": + file_decoration = "-" + file_decoration.lower() + file_decoration = apply_filename_pattern(file_decoration, p, p.seed, p.prompt) + truncated = (file_decoration[:240] + '..') if len(file_decoration) > 240 else file_decoration + filename_base = truncated + + basecount = get_next_sequence_number(path, "") for i, filedata in enumerate(images): - filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png" - filepath = os.path.join(opts.outdir_save, filename) + file_number = f"{basecount+i:05}" + filename = file_number + filename_base + ".png" + filepath = os.path.join(path, filename) if filedata.startswith("data:image/png;base64,"): filedata = filedata[len("data:image/png;base64,"):] -- cgit v1.2.3 From c938679de7b87b4f14894d9f57fe0f40dd6e3c06 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Wed, 28 Sep 2022 22:14:13 -0300 Subject: Fix memory leak and reduce memory usage --- modules/codeformer_model.py | 6 ++++-- modules/devices.py | 3 ++- modules/extras.py | 2 ++ modules/gfpgan_model.py | 11 +++++------ modules/processing.py | 33 ++++++++++++++++++++++++++------- 5 files changed, 39 insertions(+), 16 deletions(-) (limited to 'modules') diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 8fbdea24..2177291a 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -89,7 +89,7 @@ def setup_codeformer(): output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) del output - torch.cuda.empty_cache() + devices.torch_gc() except Exception as error: print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) @@ -106,7 +106,9 @@ def setup_codeformer(): restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) if shared.opts.face_restoration_unload: - self.net.to(devices.cpu) + self.net = None + self.face_helper = None + devices.torch_gc() return restored_img diff --git a/modules/devices.py b/modules/devices.py index 07bb2339..df63dd88 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,4 +1,5 @@ import torch +import gc # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility from modules import errors @@ -17,8 +18,8 @@ def get_optimal_device(): return cpu - def torch_gc(): + gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/modules/extras.py b/modules/extras.py index 9a825530..38b86167 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -98,6 +98,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v outputs.append(image) + devices.torch_gc() + return outputs, plaintext_to_html(info), '' diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 44c5dc6c..b1288f0c 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -49,6 +49,7 @@ def gfpgan(): def gfpgan_fix_faces(np_image): + global loaded_gfpgan_model model = gfpgan() np_image_bgr = np_image[:, :, ::-1] @@ -56,7 +57,9 @@ def gfpgan_fix_faces(np_image): np_image = gfpgan_output_bgr[:, :, ::-1] if shared.opts.face_restoration_unload: - model.gfpgan.to(devices.cpu) + del model + loaded_gfpgan_model = None + devices.torch_gc() return np_image @@ -83,11 +86,7 @@ def setup_gfpgan(): return "GFPGAN" def restore(self, np_image): - np_image_bgr = np_image[:, :, ::-1] - cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) - np_image = gfpgan_output_bgr[:, :, ::-1] - - return np_image + return gfpgan_fix_faces(np_image) shared.face_restorers.append(FaceRestorerGFPGAN()) except Exception: diff --git a/modules/processing.py b/modules/processing.py index 4ecdfcd2..de5cda79 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -12,7 +12,7 @@ import cv2 from skimage import exposure import modules.sd_hijack -from modules import devices, prompt_parser, masking +from modules import devices, prompt_parser, masking, lowvram from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img from modules.shared import opts, cmd_opts, state @@ -335,7 +335,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1: state.job_count = p.n_iter - for n in range(p.n_iter): + for n in range(p.n_iter): + with torch.no_grad(), precision_scope("cuda"), ema_scope(): if state.interrupted: break @@ -368,22 +369,32 @@ def process_images(p: StableDiffusionProcessing) -> Processed: x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + del samples_ddim + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + + devices.torch_gc() + if opts.filter_nsfw: import modules.safety as safety x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) - for i, x_sample in enumerate(x_samples_ddim): + for i, x_sample in enumerate(x_samples_ddim): + with torch.no_grad(), precision_scope("cuda"), ema_scope(): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) - if p.restore_faces: + if p.restore_faces: + with torch.no_grad(), precision_scope("cuda"), ema_scope(): if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") - devices.torch_gc() - x_sample = modules.face_restoration.restore_faces(x_sample) + devices.torch_gc() + + with torch.no_grad(), precision_scope("cuda"), ema_scope(): image = Image.fromarray(x_sample) if p.color_corrections is not None and i < len(p.color_corrections): @@ -411,8 +422,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts.append(infotext(n, i)) output_images.append(image) - state.nextjob() + del x_samples_ddim + devices.torch_gc() + + state.nextjob() + + with torch.no_grad(), precision_scope("cuda"), ema_scope(): p.color_corrections = None index_of_first_image = 0 @@ -648,4 +664,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask + del x + devices.torch_gc() + return samples -- cgit v1.2.3 From c2d5b29040132c171bc4d77f1f63da972306f22c Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Thu, 29 Sep 2022 01:14:54 -0300 Subject: Move silu to sd_hijack --- modules/sd_hijack.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index bfbd07f9..4bc58fa2 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -12,6 +12,7 @@ from ldm.util import default from einops import rearrange import ldm.modules.attention import ldm.modules.diffusionmodules.model +from torch.nn.functional import silu # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion @@ -100,14 +101,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def nonlinearity_hijack(x): - # swish - t = torch.sigmoid(x) - x *= t - del t - - return x - def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) @@ -245,11 +238,12 @@ class StableDiffusionModelHijack: m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) self.clip = m.cond_stage_model + ldm.modules.diffusionmodules.model.nonlinearity = silu + if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward def flatten(el): -- cgit v1.2.3 From e82ea202997cbcd2ab72891cd075d9ba270eb67d Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Fri, 30 Sep 2022 15:26:18 -0500 Subject: Optimize model loader Child classes only get populated to __subclassess__ when they are imported. We don't actually need to import any of them to webui any more, so clean up webUI imports and make sure loader imports children. Also, fix command line paths not actually being passed to the scalers. --- modules/modelloader.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/modelloader.py b/modules/modelloader.py index 1106aeb7..b1721671 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -4,7 +4,6 @@ import importlib from urllib.parse import urlparse from basicsr.utils.download_util import load_file_from_url - from modules import shared from modules.upscaler import Upscaler from modules.paths import script_path, models_path @@ -120,16 +119,30 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): def load_upscalers(): + sd = shared.script_path + # We can only do this 'magic' method to dynamically load upscalers if they are referenced, + # so we'll try to import any _model.py files before looking in __subclasses__ + modules_dir = os.path.join(sd, "modules") + for file in os.listdir(modules_dir): + if "_model.py" in file: + model_name = file.replace("_model.py", "") + full_model = f"modules.{model_name}_model" + try: + importlib.import_module(full_model) + except: + pass datas = [] + c_o = vars(shared.cmd_opts) for cls in Upscaler.__subclasses__(): name = cls.__name__ module_name = cls.__module__ module = importlib.import_module(module_name) class_ = getattr(module, name) - cmd_name = f"{name.lower().replace('upscaler', '')}-models-path" + cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" opt_string = None try: - opt_string = shared.opts.__getattr__(cmd_name) + if cmd_name in c_o: + opt_string = c_o[cmd_name] except: pass scaler = class_(opt_string) -- cgit v1.2.3 From 8deae077004f0332ca607fc3a5d568b1a4705bec Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Fri, 30 Sep 2022 15:28:37 -0500 Subject: Add ScuNET DeNoiser/Upscaler Q&D Implementation of ScuNET, thanks to our handy model loader. :P https://github.com/cszn/SCUNet --- modules/scunet_model.py | 90 +++++++++++++++ modules/scunet_model_arch.py | 265 +++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 1 + 3 files changed, 356 insertions(+) create mode 100644 modules/scunet_model.py create mode 100644 modules/scunet_model_arch.py (limited to 'modules') diff --git a/modules/scunet_model.py b/modules/scunet_model.py new file mode 100644 index 00000000..7987ac14 --- /dev/null +++ b/modules/scunet_model.py @@ -0,0 +1,90 @@ +import os.path +import sys +import traceback + +import PIL.Image +import numpy as np +import torch +from basicsr.utils.download_util import load_file_from_url + +import modules.upscaler +from modules import shared, modelloader +from modules.paths import models_path +from modules.scunet_model_arch import SCUNet as net + + +class UpscalerScuNET(modules.upscaler.Upscaler): + def __init__(self, dirname): + self.name = "ScuNET" + self.model_path = os.path.join(models_path, self.name) + self.model_name = "ScuNET GAN" + self.model_name2 = "ScuNET PSNR" + self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth" + self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth" + self.user_path = dirname + super().__init__() + model_paths = self.find_models(ext_filter=[".pth"]) + scalers = [] + add_model2 = True + for file in model_paths: + if "http" in file: + name = self.model_name + else: + name = modelloader.friendly_name(file) + if name == self.model_name2 or file == self.model_url2: + add_model2 = False + try: + scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) + scalers.append(scaler_data) + except Exception: + print(f"Error loading ScuNET model: {file}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + if add_model2: + scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) + scalers.append(scaler_data2) + self.scalers = scalers + + def do_upscale(self, img: PIL.Image, selected_file): + torch.cuda.empty_cache() + + model = self.load_model(selected_file) + if model is None: + return img + + device = shared.device + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(shared.device) + + img = img.to(device) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + torch.cuda.empty_cache() + return PIL.Image.fromarray(output, 'RGB') + + def load_model(self, path: str): + device = shared.device + if "http" in path: + filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, + progress=True) + else: + filename = path + if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: + print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr) + return None + + model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) + model.load_state_dict(torch.load(filename), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + model = model.to(device) + + return model + diff --git a/modules/scunet_model_arch.py b/modules/scunet_model_arch.py new file mode 100644 index 00000000..972a2639 --- /dev/null +++ b/modules/scunet_model_arch.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from einops.layers.torch import Rearrange +from timm.models.layers import trunc_normal_, DropPath + + +class WMSA(nn.Module): + """ Self-attention module in Swin Transformer + """ + + def __init__(self, input_dim, output_dim, head_dim, window_size, type): + super(WMSA, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.head_dim = head_dim + self.scale = self.head_dim ** -0.5 + self.n_heads = input_dim // head_dim + self.window_size = window_size + self.type = type + self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) + + self.relative_position_params = nn.Parameter( + torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)) + + self.linear = nn.Linear(self.input_dim, self.output_dim) + + trunc_normal_(self.relative_position_params, std=.02) + self.relative_position_params = torch.nn.Parameter( + self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, + 2).transpose( + 0, 1)) + + def generate_mask(self, h, w, p, shift): + """ generating the mask of SW-MSA + Args: + shift: shift parameters in CyclicShift. + Returns: + attn_mask: should be (1 1 w p p), + """ + # supporting sqaure. + attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) + if self.type == 'W': + return attn_mask + + s = p - shift + attn_mask[-1, :, :s, :, s:, :] = True + attn_mask[-1, :, s:, :, :s, :] = True + attn_mask[:, -1, :, :s, :, s:] = True + attn_mask[:, -1, :, s:, :, :s] = True + attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') + return attn_mask + + def forward(self, x): + """ Forward pass of Window Multi-head Self-attention module. + Args: + x: input tensor with shape of [b h w c]; + attn_mask: attention mask, fill -inf where the value is True; + Returns: + output: tensor shape [b h w c] + """ + if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) + x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) + h_windows = x.size(1) + w_windows = x.size(2) + # sqaure validation + # assert h_windows == w_windows + + x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) + qkv = self.embedding_layer(x) + q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0) + sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale + # Adding learnable relative embedding + sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q') + # Using Attn Mask to distinguish different subwindows. + if self.type != 'W': + attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2) + sim = sim.masked_fill_(attn_mask, float("-inf")) + + probs = nn.functional.softmax(sim, dim=-1) + output = torch.einsum('hbwij,hbwjc->hbwic', probs, v) + output = rearrange(output, 'h b w p c -> b w p (h c)') + output = self.linear(output) + output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) + + if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), + dims=(1, 2)) + return output + + def relative_embedding(self): + cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)])) + relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1 + # negative is allowed + return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()] + + +class Block(nn.Module): + def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): + """ SwinTransformer Block + """ + super(Block, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + assert type in ['W', 'SW'] + self.type = type + if input_resolution <= window_size: + self.type = 'W' + + self.ln1 = nn.LayerNorm(input_dim) + self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.ln2 = nn.LayerNorm(input_dim) + self.mlp = nn.Sequential( + nn.Linear(input_dim, 4 * input_dim), + nn.GELU(), + nn.Linear(4 * input_dim, output_dim), + ) + + def forward(self, x): + x = x + self.drop_path(self.msa(self.ln1(x))) + x = x + self.drop_path(self.mlp(self.ln2(x))) + return x + + +class ConvTransBlock(nn.Module): + def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): + """ SwinTransformer and Conv Block + """ + super(ConvTransBlock, self).__init__() + self.conv_dim = conv_dim + self.trans_dim = trans_dim + self.head_dim = head_dim + self.window_size = window_size + self.drop_path = drop_path + self.type = type + self.input_resolution = input_resolution + + assert self.type in ['W', 'SW'] + if self.input_resolution <= self.window_size: + self.type = 'W' + + self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, + self.type, self.input_resolution) + self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) + self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) + + self.conv_block = nn.Sequential( + nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), + nn.ReLU(True), + nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False) + ) + + def forward(self, x): + conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1) + conv_x = self.conv_block(conv_x) + conv_x + trans_x = Rearrange('b c h w -> b h w c')(trans_x) + trans_x = self.trans_block(trans_x) + trans_x = Rearrange('b h w c -> b c h w')(trans_x) + res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) + x = x + res + + return x + + +class SCUNet(nn.Module): + # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256): + def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256): + super(SCUNet, self).__init__() + if config is None: + config = [2, 2, 2, 2, 2, 2, 2] + self.config = config + self.dim = dim + self.head_dim = 32 + self.window_size = 8 + + # drop path rate for each layer + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] + + self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] + + begin = 0 + self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution) + for i in range(config[0])] + \ + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)] + + begin += config[0] + self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 2) + for i in range(config[1])] + \ + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)] + + begin += config[1] + self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 4) + for i in range(config[2])] + \ + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)] + + begin += config[2] + self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 8) + for i in range(config[3])] + + begin += config[3] + self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \ + [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 4) + for i in range(config[4])] + + begin += config[4] + self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \ + [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution // 2) + for i in range(config[5])] + + begin += config[5] + self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \ + [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], + 'W' if not i % 2 else 'SW', input_resolution) + for i in range(config[6])] + + self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] + + self.m_head = nn.Sequential(*self.m_head) + self.m_down1 = nn.Sequential(*self.m_down1) + self.m_down2 = nn.Sequential(*self.m_down2) + self.m_down3 = nn.Sequential(*self.m_down3) + self.m_body = nn.Sequential(*self.m_body) + self.m_up3 = nn.Sequential(*self.m_up3) + self.m_up2 = nn.Sequential(*self.m_up2) + self.m_up1 = nn.Sequential(*self.m_up1) + self.m_tail = nn.Sequential(*self.m_tail) + # self.apply(self._init_weights) + + def forward(self, x0): + + h, w = x0.size()[-2:] + paddingBottom = int(np.ceil(h / 64) * 64 - h) + paddingRight = int(np.ceil(w / 64) * 64 - w) + x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0) + + x1 = self.m_head(x0) + x2 = self.m_down1(x1) + x3 = self.m_down2(x2) + x4 = self.m_down3(x3) + x = self.m_body(x4) + x = self.m_up3(x + x4) + x = self.m_up2(x + x3) + x = self.m_up1(x + x2) + x = self.m_tail(x + x1) + + x = x[..., :h, :w] + + return x + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index 8428c7a3..a48b995a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -40,6 +40,7 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN')) parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN')) parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN')) +parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(model_path, 'ScuNET')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR')) parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") -- cgit v1.2.3 From abdbf1de646f007b6d76cfb3f416fdfaadb57903 Mon Sep 17 00:00:00 2001 From: Liam Date: Thu, 29 Sep 2022 14:40:47 -0400 Subject: token counters now update when roll artist and style buttons are pressed https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/1194#issuecomment-1261203893 --- modules/ui.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 15572bb0..5eea1860 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -539,6 +539,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): roll.click( fn=roll_artist, + _js="roll_artist_txt2img", inputs=[ txt2img_prompt, ], @@ -743,6 +744,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): roll.click( fn=roll_artist, + _js="roll_artist_img2img", inputs=[ img2img_prompt, ], @@ -753,6 +755,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_js_funcs = ["update_style_txt2img", "update_style_img2img"] for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): button.click( @@ -764,9 +767,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], ) - for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns): + for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): button.click( fn=apply_styles, + _js=js_func, inputs=[prompt, negative_prompt, style1, style2], outputs=[prompt, negative_prompt, style1, style2], ) -- cgit v1.2.3 From ff8dc1908af088d0ed43fb85baad662733c5ca9c Mon Sep 17 00:00:00 2001 From: Liam Date: Thu, 29 Sep 2022 15:47:06 -0400 Subject: fixed token counter for prompt editing --- modules/ui.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 5eea1860..6bf28562 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -11,6 +11,7 @@ import time import traceback import platform import subprocess as sp +from functools import reduce import numpy as np import torch @@ -32,6 +33,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste +from modules.prompt_parser import get_learned_conditioning_prompt_schedules # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI mimetypes.init() @@ -345,8 +347,11 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: outputs=[seed, dummy_component] ) -def update_token_counter(text): - tokens, token_count, max_length = model_hijack.tokenize(text) +def update_token_counter(text, steps): + prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps) + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) + prompts = [prompt_text for step,prompt_text in flat_prompts] + tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) style_class = ' class="red"' if (token_count > max_length) else "" return f"{token_count}/{max_length}" @@ -364,8 +369,7 @@ def create_toprow(is_img2img): roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) paste = gr.Button(value=paste_symbol, elem_id="paste") token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter]) + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") with gr.Column(scale=10, elem_id="style_pos_col"): prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) @@ -396,7 +400,7 @@ def create_toprow(is_img2img): prompt_style_apply = gr.Button('Apply style', elem_id="style_apply") save_style = gr.Button('Create style', elem_id="style_create") - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button def setup_progressbar(progressbar, preview, id_part): @@ -419,7 +423,7 @@ def setup_progressbar(progressbar, preview, id_part): def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) with gr.Row(elem_id='txt2img_progress_row'): @@ -568,9 +572,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), ] modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) + token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste = create_toprow(is_img2img=True) + img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): with gr.Column(scale=1): @@ -793,6 +798,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): (denoising_strength, "Denoising strength"), ] modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt) + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Row().style(equal_height=False): -- cgit v1.2.3 From 3c6a049fc3c6b54ada3736710a7e86663ea7f3d9 Mon Sep 17 00:00:00 2001 From: Liam Date: Fri, 30 Sep 2022 12:12:44 -0400 Subject: consolidated token counter functions --- modules/ui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 6bf28562..40c08984 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -543,7 +543,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): roll.click( fn=roll_artist, - _js="roll_artist_txt2img", + _js="update_txt2img_tokens", inputs=[ txt2img_prompt, ], @@ -749,7 +749,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): roll.click( fn=roll_artist, - _js="roll_artist_img2img", + _js="update_img2img_tokens", inputs=[ img2img_prompt, ], @@ -760,7 +760,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] - style_js_funcs = ["update_style_txt2img", "update_style_img2img"] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): button.click( -- cgit v1.2.3 From bdaa36c84470adbdce3e98c01a69af5e95adfb02 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 30 Sep 2022 23:53:25 -0400 Subject: When device is MPS, use CPU for GFPGAN instead GFPGAN will not work if the device is MPS, so default to CPU instead. --- modules/devices.py | 2 +- modules/gfpgan_model.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 07bb2339..08bb26d6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -34,7 +34,7 @@ errors.run(enable_tf32, "Enabling TF32") device = get_optimal_device() -device_codeformer = cpu if has_mps else device +device_gfpgan = device_codeformer = cpu if device.type == 'mps' else device def randn(seed, shape): diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index bb30d733..fcd8544a 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -21,7 +21,7 @@ def gfpgann(): global loaded_gfpgan_model global model_path if loaded_gfpgan_model is not None: - loaded_gfpgan_model.gfpgan.to(shared.device) + loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) return loaded_gfpgan_model if gfpgan_constructor is None: @@ -36,8 +36,8 @@ def gfpgann(): else: print("Unable to load gfpgan model!") return None - model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) - model.gfpgan.to(shared.device) + model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) + model.gfpgan.to(devices.device_gfpgan) loaded_gfpgan_model = model return model -- cgit v1.2.3 From 4c2478a68a4f11959fe4887d38e0436eac19f97e Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 1 Oct 2022 18:30:53 +0100 Subject: add script reload method --- modules/scripts.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 7c3bd5e7..3c14b9e3 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -165,3 +165,12 @@ class ScriptRunner: scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + +def reload_scripts(basedir): + global scripts_txt2img,scripts_img2img + + scripts_data.clear() + load_scripts(basedir) + + scripts_txt2img = ScriptRunner() + scripts_img2img = ScriptRunner() -- cgit v1.2.3 From 4f8490cd5630823ac44de8b5c5e4325bdbbea7fa Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 1 Oct 2022 18:33:31 +0100 Subject: add restart button --- modules/ui.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 15572bb0..ec6aaa28 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1002,6 +1002,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): _js='function(){}' ) + def request_restart(): + settings_interface.gradio_ref.do_restart = True + + restart_gradio = gr.Button(value='Restart Gradio and Refresh Scripts') + restart_gradio.click( + fn=request_restart, + inputs=[], + outputs=[], + _js='function(){document.body.innerHTML=\'

Reloading

\';setTimeout(function(){location.reload()},2000)}' + ) + if column is not None: column.__exit__() @@ -1026,7 +1037,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): css += css_hide_progressbar with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - + + settings_interface.gradio_ref = demo + with gr.Tabs() as tabs: for interface, label, ifid in interfaces: with gr.TabItem(label, id=ifid): -- cgit v1.2.3 From 121ed7d36febe94995774973b5edc1ba2ba84aad Mon Sep 17 00:00:00 2001 From: Alexandre Simard Date: Sat, 1 Oct 2022 14:04:20 -0400 Subject: Add progress bar for SwinIR in cmd I do not know how to add them to the UI... --- modules/swinir_model.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/swinir_model.py b/modules/swinir_model.py index 41fda5a7..9bd454c6 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -5,6 +5,7 @@ import numpy as np import torch from PIL import Image from basicsr.utils.download_util import load_file_from_url +from tqdm import tqdm from modules import modelloader from modules.paths import models_path @@ -122,18 +123,20 @@ def inference(img, model, tile, tile_overlap, window_size, scale): E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) W = torch.zeros_like(E, dtype=torch.half, device=device) - for h_idx in h_idx_list: - for w_idx in w_idx_list: - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) + with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: + for h_idx in h_idx_list: + for w_idx in w_idx_list: + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] + out_patch = model(in_patch) + out_patch_mask = torch.ones_like(out_patch) + + E[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch) + W[ + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf + ].add_(out_patch_mask) + pbar.update(1) output = E.div_(W) return output -- cgit v1.2.3 From afaa03c5fd05f48ed9c9f15558ea6f0bc4f61628 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 1 Oct 2022 22:43:45 +0100 Subject: add redefinition guard to gradio_routes_templates_response --- modules/ui.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index ec6aaa28..fd057916 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1219,12 +1219,13 @@ for filename in sorted(os.listdir(jsdir)): javascript += f"\n" -def template_response(*args, **kwargs): - res = gradio_routes_templates_response(*args, **kwargs) - res.body = res.body.replace(b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res +if 'gradio_routes_templates_response' not in globals(): + def template_response(*args, **kwargs): + res = gradio_routes_templates_response(*args, **kwargs) + res.body = res.body.replace(b'', f'{javascript}'.encode("utf8")) + res.init_headers() + return res + gradio_routes_templates_response = gradio.routes.templates.TemplateResponse + gradio.routes.templates.TemplateResponse = template_response -gradio_routes_templates_response = gradio.routes.templates.TemplateResponse -gradio.routes.templates.TemplateResponse = template_response -- cgit v1.2.3 From 6048002dade91b82b1ce9fea3c6ff5b5c1f8c990 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 1 Oct 2022 23:10:07 +0100 Subject: Add scope warning to refresh button --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index fd057916..72846a12 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1005,7 +1005,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): def request_restart(): settings_interface.gradio_ref.do_restart = True - restart_gradio = gr.Button(value='Restart Gradio and Refresh Scripts') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') restart_gradio.click( fn=request_restart, inputs=[], -- cgit v1.2.3 From 027c5aae5546ff3650347cb3c2b87df4415ab900 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 1 Oct 2022 23:29:26 +0100 Subject: update reloading message style --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 72846a12..7b2359c2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1010,7 +1010,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): fn=request_restart, inputs=[], outputs=[], - _js='function(){document.body.innerHTML=\'

Reloading

\';setTimeout(function(){location.reload()},2000)}' + _js='function(){document.body.innerHTML=\'

Reloading...

\';setTimeout(function(){location.reload()},2000)}' ) if column is not None: -- cgit v1.2.3 From 0aa354bd5e811e2b41b17a3052cf5d4c8190d533 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 2 Oct 2022 00:13:47 +0100 Subject: remove styling from python side --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 7b2359c2..cb859ac4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1010,7 +1010,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): fn=request_restart, inputs=[], outputs=[], - _js='function(){document.body.innerHTML=\'

Reloading...

\';setTimeout(function(){location.reload()},2000)}' + _js='function(){restart_reload()}' ) if column is not None: -- cgit v1.2.3 From cf33268d686986a24f2e04eb615f01ed53bfe308 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 2 Oct 2022 01:18:42 +0100 Subject: add script body only refresh --- modules/scripts.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 3c14b9e3..788397f5 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -162,10 +162,33 @@ class ScriptRunner: return processed + def reload_sources(self): + for si,script in list(enumerate(self.scripts)): + with open(script.filename, "r", encoding="utf8") as file: + args_from = script.args_from + args_to = script.args_to + filename = script.filename + text = file.read() + + from types import ModuleType + compiled = compile(text, filename, 'exec') + module = ModuleType(script.filename) + exec(compiled, module.__dict__) + + for key, script_class in module.__dict__.items(): + if type(script_class) == type and issubclass(script_class, Script): + self.scripts[si] = script_class() + self.scripts[si].filename = filename + self.scripts[si].args_from = args_from + self.scripts[si].args_to = args_to scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() +def reload_script_body_only(): + scripts_txt2img.reload_sources() + scripts_img2img.reload_sources() + def reload_scripts(basedir): global scripts_txt2img,scripts_img2img -- cgit v1.2.3 From 07e40ad7f23472fc1c781fe1cc6c1ee403413918 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 2 Oct 2022 01:19:55 +0100 Subject: add custom script body only refresh option --- modules/ui.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index cb859ac4..eb7c0585 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1012,6 +1012,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): outputs=[], _js='function(){restart_reload()}' ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='primary') + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[], + _js='function(){}' + ) if column is not None: column.__exit__() -- cgit v1.2.3 From 2deea867814272f1f089b60e9ba8d587c16b2fb1 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 2 Oct 2022 01:36:30 +0100 Subject: Put reload buttons in row and add secondary style --- modules/ui.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index eb7c0585..963a2c61 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1002,27 +1002,30 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): _js='function(){}' ) - def request_restart(): - settings_interface.gradio_ref.do_restart = True + with gr.Row(): + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - restart_gradio.click( - fn=request_restart, - inputs=[], - outputs=[], - _js='function(){restart_reload()}' - ) def reload_scripts(): modules.scripts.reload_script_body_only() - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='primary') reload_script_bodies.click( fn=reload_scripts, inputs=[], outputs=[], _js='function(){}' ) + + def request_restart(): + settings_interface.gradio_ref.do_restart = True + + restart_gradio.click( + fn=request_restart, + inputs=[], + outputs=[], + _js='function(){restart_reload()}' + ) if column is not None: column.__exit__() -- cgit v1.2.3 From 3cf1a96006daffedb8ecd0ae142eca4c4da06105 Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Sat, 1 Oct 2022 21:11:03 -0700 Subject: added safety for blank directory naming patterns --- modules/images.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index f1aed5d6..e7894b4c 100644 --- a/modules/images.py +++ b/modules/images.py @@ -311,7 +311,7 @@ def apply_filename_pattern(x, p, seed, prompt): x = x.replace("[cfg]", str(p.cfg_scale)) x = x.replace("[width]", str(p.width)) x = x.replace("[height]", str(p.height)) - x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False)) + x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "No styles", replace_spaces=False)) x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) @@ -374,7 +374,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) if save_to_dirs: - dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt) + dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ ') path = os.path.join(path, dirname) os.makedirs(path, exist_ok=True) -- cgit v1.2.3 From 70f526704721a303ae045f6406439dcceee4302e Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Sat, 1 Oct 2022 21:18:15 -0700 Subject: use os.path.normpath for better safety checking --- modules/images.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index e7894b4c..5ef7eb92 100644 --- a/modules/images.py +++ b/modules/images.py @@ -374,8 +374,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) if save_to_dirs: - dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ ') - path = os.path.join(path, dirname) + dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt) + path = os.path.normpath(os.path.join(path, dirname)) os.makedirs(path, exist_ok=True) -- cgit v1.2.3 From 32edf1732f27a1fad5133667c22b948adda1b070 Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Sat, 1 Oct 2022 21:37:14 -0700 Subject: os.path.normpath wasn't working, reverting to manual strip --- modules/images.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 5ef7eb92..4998e92c 100644 --- a/modules/images.py +++ b/modules/images.py @@ -374,8 +374,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) if save_to_dirs: - dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt) - path = os.path.normpath(os.path.join(path, dirname)) + dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /') + path = os.path.join(path, dirname) os.makedirs(path, exist_ok=True) -- cgit v1.2.3 From 820f1dc96b1979d7e92170c161db281ee8bd988b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 15:03:39 +0300 Subject: initial support for training textual inversion --- modules/devices.py | 3 +- modules/processing.py | 13 +- modules/sd_hijack.py | 324 ++++--------------------- modules/sd_hijack_optimizations.py | 164 +++++++++++++ modules/sd_models.py | 4 +- modules/shared.py | 3 +- modules/textual_inversion/dataset.py | 76 ++++++ modules/textual_inversion/textual_inversion.py | 258 ++++++++++++++++++++ modules/textual_inversion/ui.py | 32 +++ modules/ui.py | 139 +++++++++-- 10 files changed, 717 insertions(+), 299 deletions(-) create mode 100644 modules/sd_hijack_optimizations.py create mode 100644 modules/textual_inversion/dataset.py create mode 100644 modules/textual_inversion/textual_inversion.py create mode 100644 modules/textual_inversion/ui.py (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 07bb2339..ff82f2f6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -32,10 +32,9 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") - device = get_optimal_device() device_codeformer = cpu if has_mps else device - +dtype = torch.float16 def randn(seed, shape): # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. diff --git a/modules/processing.py b/modules/processing.py index 7eeb5191..8223423a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -56,7 +56,7 @@ class StableDiffusionProcessing: self.prompt: str = prompt self.prompt_for_display: str = None self.negative_prompt: str = (negative_prompt or "") - self.styles: str = styles + self.styles: list = styles or [] self.seed: int = seed self.subseed: int = subseed self.subseed_strength: float = subseed_strength @@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), - "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta), + "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), } generation_params.update(p.extra_generation_params) @@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: fix_seed(p) - os.makedirs(p.outpath_samples, exist_ok=True) - os.makedirs(p.outpath_grids, exist_ok=True) + if p.outpath_samples is not None: + os.makedirs(p.outpath_samples, exist_ok=True) + + if p.outpath_grids is not None: + os.makedirs(p.outpath_grids, exist_ok=True) modules.sd_hijack.model_hijack.apply_circular(p.tiling) @@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) if os.path.exists(cmd_opts.embeddings_dir): - model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model) + model_hijack.embedding_db.load_textual_inversion_embeddings() infotexts = [] output_images = [] diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fa7eaeb8..fd57e5c5 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -6,244 +6,41 @@ import torch import numpy as np from torch import einsum -from modules import prompt_parser +import modules.textual_inversion.textual_inversion +from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules.shared import opts, device, cmd_opts -from ldm.util import default -from einops import rearrange import ldm.modules.attention import ldm.modules.diffusionmodules.model +attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward +diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity +diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward -# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None): - h = self.heads - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - del context, x +def apply_optimizations(): + if cmd_opts.opt_split_attention_v1: + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward + ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale +def undo_optimizations(): + ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward + ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward - s2 = s1.softmax(dim=-1) - del s1 - - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - - -# taken from https://github.com/Doggettx/stable-diffusion -def split_cross_attention_forward(self, x, context=None, mask=None): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - - del q, k, v - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - -def nonlinearity_hijack(x): - # swish - t = torch.sigmoid(x) - x *= t - del t - - return x - -def cross_attention_attnblock_forward(self, x): - h_ = x - h_ = self.norm(h_) - q1 = self.q(h_) - k1 = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q1.shape - - q2 = q1.reshape(b, c, h*w) - del q1 - - q = q2.permute(0, 2, 1) # b,hw,c - del q2 - - k = k1.reshape(b, c, h*w) # b,c,hw - del k1 - - h_ = torch.zeros_like(k, device=q.device) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - mem_required = tensor_size * 2.5 - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - - w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w2 = w1 * (int(c)**(-0.5)) - del w1 - w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) - del w2 - - # attend to values - v1 = v.reshape(b, c, h*w) - w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w3 - - h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w4 - - h2 = h_.reshape(b, c, h, w) - del h_ - - h3 = self.proj_out(h2) - del h2 - - h3 += x - - return h3 class StableDiffusionModelHijack: - ids_lookup = {} - word_embeddings = {} - word_embeddings_checksums = {} fixes = None comments = [] - dir_mtime = None layers = None circular_enabled = False clip = None - def load_textual_inversion_embeddings(self, dirname, model): - mt = os.path.getmtime(dirname) - if self.dir_mtime is not None and mt <= self.dir_mtime: - return - - self.dir_mtime = mt - self.ids_lookup.clear() - self.word_embeddings.clear() - - tokenizer = model.cond_stage_model.tokenizer - - def const_hash(a): - r = 0 - for v in a: - r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF - return r - - def process_file(path, filename): - name = os.path.splitext(filename)[0] - - data = torch.load(path, map_location="cpu") - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - - self.word_embeddings[name] = emb.detach().to(device) - self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}' - - ids = tokenizer([name], add_special_tokens=False)['input_ids'][0] - - first_id = ids[0] - if first_id not in self.ids_lookup: - self.ids_lookup[first_id] = [] - self.ids_lookup[first_id].append((ids, name)) - - for fn in os.listdir(dirname): - try: - fullfn = os.path.join(dirname, fn) - - if os.stat(fullfn).st_size == 0: - continue - - process_file(fullfn, fn) - except Exception: - print(f"Error loading emedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - continue - - print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") + embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) def hijack(self, m): model_embeddings = m.cond_stage_model.transformer.text_model.embeddings @@ -253,12 +50,7 @@ class StableDiffusionModelHijack: self.clip = m.cond_stage_model - if cmd_opts.opt_split_attention_v1: - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack - ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + apply_optimizations() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped - self.hijack = hijack + self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer self.max_length = wrapped.max_length self.token_mults = {} @@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id @@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = self.hijack.ids_lookup.get(token, None) + embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - if possible_matches is None: + if embedding is None: remade_tokens.append(token) multipliers.append(weight) + i += 1 else: - found = False - for ids, word in possible_matches: - if tokens[i:i + len(ids)] == ids: - emb_len = int(self.hijack.word_embeddings[word].shape[0]) - fixes.append((len(remade_tokens), word)) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - i += len(ids) - 1 - found = True - used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) - break - - if not found: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [weight] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += emb_len if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = self.hijack.ids_lookup.get(token, None) + embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: mult *= mult_change - elif possible_matches is None: + i += 1 + elif embedding is None: remade_tokens.append(token) multipliers.append(mult) + i += 1 else: - found = False - for ids, word in possible_matches: - if tokens[i:i+len(ids)] == ids: - emb_len = int(self.hijack.word_embeddings[word].shape[0]) - fixes.append((len(remade_tokens), word)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - i += len(ids) - 1 - found = True - used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) - break - - if not found: - remade_tokens.append(token) - multipliers.append(mult) - - i += 1 + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += emb_len if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] @@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - self.hijack.fixes = hijack_fixes self.hijack.comments = hijack_comments @@ -517,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module): inputs_embeds = self.wrapped(input_ids) - if batch_fixes is not None: - for fixes, tensor in zip(batch_fixes, inputs_embeds): - for offset, word in fixes: - emb = self.embeddings.word_embeddings[word] - emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) - tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len] + if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: + return inputs_embeds + + vecs = [] + for fixes, tensor in zip(batch_fixes, inputs_embeds): + for offset, embedding in fixes: + emb = embedding.vec + emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) + tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) + + vecs.append(tensor) - return inputs_embeds + return torch.stack(vecs) def add_circular_option_to_conv_2d(): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py new file mode 100644 index 00000000..9c079e57 --- /dev/null +++ b/modules/sd_hijack_optimizations.py @@ -0,0 +1,164 @@ +import math +import torch +from torch import einsum + +from ldm.util import default +from einops import rearrange + + +# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion +def split_cross_attention_forward_v1(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + + +# taken from https://github.com/Doggettx/stable-diffusion +def split_cross_attention_forward(self, x, context=None, mask=None): + h = self.heads + + q_in = self.to_q(x) + context = default(context, x) + k_in = self.to_k(context) * self.scale + v_in = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + +def nonlinearity_hijack(x): + # swish + t = torch.sigmoid(x) + x *= t + del t + + return x + +def cross_attention_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q1 = self.q(h_) + k1 = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 + + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 + + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 diff --git a/modules/sd_models.py b/modules/sd_models.py index 2539f14c..5b3dbdc7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader +from modules import shared, modelloader, devices from modules.paths import models_path model_dir = "Stable-diffusion" @@ -134,6 +134,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): if not shared.cmd_opts.no_half: model.half() + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + model.sd_model_hash = sd_model_hash model.sd_model_checkpint = checkpoint_file diff --git a/modules/shared.py b/modules/shared.py index ac968b2d..ac0bc480 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -78,6 +78,7 @@ class State: current_latent = None current_image = None current_image_sampling_step = 0 + textinfo = None def interrupt(self): self.interrupted = True @@ -88,7 +89,7 @@ class State: self.current_image_sampling_step = 0 def get_job_timestamp(self): - return datetime.datetime.now().strftime("%Y%m%d%H%M%S") + return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? state = State() diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py new file mode 100644 index 00000000..7e134a08 --- /dev/null +++ b/modules/textual_inversion/dataset.py @@ -0,0 +1,76 @@ +import os +import numpy as np +import PIL +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random +import tqdm + + +class PersonalizedBase(Dataset): + def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None): + + self.placeholder_token = placeholder_token + + self.size = size + self.width = width + self.height = height + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + self.dataset = [] + + with open(template_file, "r") as file: + lines = [x.strip() for x in file.readlines()] + + self.lines = lines + + assert data_root, 'dataset directory not specified' + + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] + print("Preparing dataset...") + for path in tqdm.tqdm(self.image_paths): + image = Image.open(path) + image = image.convert('RGB') + image = image.resize((self.width, self.height), PIL.Image.BICUBIC) + + filename = os.path.basename(path) + filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-') + filename_tokens = [token for token in filename_tokens if token.isalpha()] + + npimage = np.array(image).astype(np.uint8) + npimage = (npimage / 127.5 - 1.0).astype(np.float32) + + torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) + torchdata = torch.moveaxis(torchdata, 2, 0) + + init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() + + self.dataset.append((init_latent, filename_tokens)) + + self.length = len(self.dataset) * repeats + + self.initial_indexes = np.arange(self.length) % len(self.dataset) + self.indexes = None + self.shuffle() + + def shuffle(self): + self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + + def __len__(self): + return self.length + + def __getitem__(self, i): + if i % len(self.dataset) == 0: + self.shuffle() + + index = self.indexes[i % len(self.indexes)] + x, filename_tokens = self.dataset[index] + + text = random.choice(self.lines) + text = text.replace("[name]", self.placeholder_token) + text = text.replace("[filewords]", ' '.join(filename_tokens)) + + return x, text diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py new file mode 100644 index 00000000..c0baaace --- /dev/null +++ b/modules/textual_inversion/textual_inversion.py @@ -0,0 +1,258 @@ +import os +import sys +import traceback + +import torch +import tqdm +import html +import datetime + +from modules import shared, devices, sd_hijack, processing +import modules.textual_inversion.dataset + + +class Embedding: + def __init__(self, vec, name, step=None): + self.vec = vec + self.name = name + self.step = step + self.cached_checksum = None + + def save(self, filename): + embedding_data = { + "string_to_token": {"*": 265}, + "string_to_param": {"*": self.vec}, + "name": self.name, + "step": self.step, + } + + torch.save(embedding_data, filename) + + def checksum(self): + if self.cached_checksum is not None: + return self.cached_checksum + + def const_hash(a): + r = 0 + for v in a: + r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF + return r + + self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' + return self.cached_checksum + +class EmbeddingDatabase: + def __init__(self, embeddings_dir): + self.ids_lookup = {} + self.word_embeddings = {} + self.dir_mtime = None + self.embeddings_dir = embeddings_dir + + def register_embedding(self, embedding, model): + + self.word_embeddings[embedding.name] = embedding + + ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0] + + first_id = ids[0] + if first_id not in self.ids_lookup: + self.ids_lookup[first_id] = [] + self.ids_lookup[first_id].append((ids, embedding)) + + return embedding + + def load_textual_inversion_embeddings(self): + mt = os.path.getmtime(self.embeddings_dir) + if self.dir_mtime is not None and mt <= self.dir_mtime: + return + + self.dir_mtime = mt + self.ids_lookup.clear() + self.word_embeddings.clear() + + def process_file(path, filename): + name = os.path.splitext(filename)[0] + + data = torch.load(path, map_location="cpu") + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + vec = emb.detach().to(devices.device, dtype=torch.float32) + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + self.register_embedding(embedding, shared.sd_model) + + for fn in os.listdir(self.embeddings_dir): + try: + fullfn = os.path.join(self.embeddings_dir, fn) + + if os.stat(fullfn).st_size == 0: + continue + + process_file(fullfn, fn) + except Exception: + print(f"Error loading emedding {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + continue + + print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") + + def find_embedding_at_position(self, tokens, offset): + token = tokens[offset] + possible_matches = self.ids_lookup.get(token, None) + + if possible_matches is None: + return None + + for ids, embedding in possible_matches: + if tokens[offset:offset + len(ids)] == ids: + return embedding + + return None + + + +def create_embedding(name, num_vectors_per_token): + init_text = '*' + + cond_model = shared.sd_model.cond_stage_model + embedding_layer = cond_model.wrapped.transformer.text_model.embeddings + + ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer(ids.to(devices.device)).squeeze(0) + vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) + + for i in range(num_vectors_per_token): + vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] + + fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + embedding = Embedding(vec, name) + embedding.step = 0 + embedding.save(fn) + + return fn + + +def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): + assert embedding_name, 'embedding not selected' + + shared.state.textinfo = "Initializing textual inversion training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name) + + if save_embedding_every > 0: + embedding_dir = os.path.join(log_directory, "embeddings") + os.makedirs(embedding_dir, exist_ok=True) + else: + embedding_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + cond_model = shared.sd_model.cond_stage_model + + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + with torch.autocast("cuda"): + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + + hijack = sd_hijack.model_hijack + + embedding = hijack.embedding_db.word_embeddings[embedding_name] + embedding.vec.requires_grad = True + + optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) + + losses = torch.zeros((32,)) + + last_saved_file = "" + last_saved_image = "" + + ititial_step = embedding.step or 0 + if ititial_step > steps: + return embedding, filename + + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + for i, (x, text) in pbar: + embedding.step = i + ititial_step + + if embedding.step > steps: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([text]) + loss = shared.sd_model(x.unsqueeze(0), c)[0] + + losses[embedding.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description(f"loss: {losses.mean():.7f}") + + if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: + last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') + embedding.save(last_saved_file) + + if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: + last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + prompt=text, + steps=20, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + processed = processing.process_images(p) + image = processed.images[0] + + shared.state.current_image = image + image.save(last_saved_image) + + last_saved_image += f", prompt: {text}" + + shared.state.job_no = embedding.step + + shared.state.textinfo = f""" +

+Loss: {losses.mean():.7f}
+Step: {embedding.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+

+""" + + embedding.cached_checksum = None + embedding.save(filename) + + return embedding, filename + diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py new file mode 100644 index 00000000..ce3677a9 --- /dev/null +++ b/modules/textual_inversion/ui.py @@ -0,0 +1,32 @@ +import html + +import gradio as gr + +import modules.textual_inversion.textual_inversion as ti +from modules import sd_hijack, shared + + +def create_embedding(name, nvpt): + filename = ti.create_embedding(name, nvpt) + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" + + +def train_embedding(*args): + + try: + sd_hijack.undo_optimizations() + + embedding, filename = ti.train_embedding(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps. +Embedding saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + sd_hijack.apply_optimizations() diff --git a/modules/ui.py b/modules/ui.py index 15572bb0..57aef6ff 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,6 +21,7 @@ import gradio as gr import gradio.utils import gradio.routes +from modules import sd_hijack from modules.paths import script_path from modules.shared import opts, cmd_opts import modules.shared as shared @@ -32,6 +33,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste +import modules.textual_inversion.ui # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI mimetypes.init() @@ -142,8 +144,8 @@ def save_files(js_data, images, index): return '', '', plaintext_to_html(f"Saved: {filenames[0]}") -def wrap_gradio_call(func): - def f(*args, **kwargs): +def wrap_gradio_call(func, extra_outputs=None): + def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled if run_memmon: shared.mem_mon.monitor() @@ -159,7 +161,10 @@ def wrap_gradio_call(func): shared.state.job = "" shared.state.job_count = 0 - res = [None, '', f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] + if extra_outputs_array is None: + extra_outputs_array = [None, ''] + + res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] elapsed = time.perf_counter() - t @@ -179,6 +184,7 @@ def wrap_gradio_call(func): res[-1] += f"

Time taken: {elapsed:.2f}s

{vram_html}
" shared.state.interrupted = False + shared.state.job_count = 0 return tuple(res) @@ -187,7 +193,7 @@ def wrap_gradio_call(func): def check_progress_call(id_part): if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False) + return "", gr_show(False), gr_show(False), gr_show(False) progress = 0 @@ -219,13 +225,19 @@ def check_progress_call(id_part): else: preview_visibility = gr_show(True) - return f"

{progressbar}

", preview_visibility, image + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"

{progressbar}

", preview_visibility, image, textinfo_result def check_progress_call_initial(id_part): shared.state.job_count = -1 shared.state.current_latent = None shared.state.current_image = None + shared.state.textinfo = None return check_progress_call(id_part) @@ -399,13 +411,16 @@ def create_toprow(is_img2img): return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste -def setup_progressbar(progressbar, preview, id_part): +def setup_progressbar(progressbar, preview, id_part, textinfo=None): + if textinfo is None: + textinfo = gr.HTML(visible=False) + check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) check_progress.click( fn=lambda: check_progress_call(id_part), show_progress=False, inputs=[], - outputs=[progressbar, preview, preview], + outputs=[progressbar, preview, preview, textinfo], ) check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) @@ -413,11 +428,14 @@ def setup_progressbar(progressbar, preview, id_part): fn=lambda: check_progress_call_initial(id_part), show_progress=False, inputs=[], - outputs=[progressbar, preview, preview], + outputs=[progressbar, preview, preview, textinfo], ) -def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): +def create_ui(wrap_gradio_gpu_call): + import modules.img2img + import modules.txt2img + with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) @@ -483,7 +501,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) txt2img_args = dict( - fn=txt2img, + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), _js="submit", inputs=[ txt2img_prompt, @@ -675,7 +693,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ) img2img_args = dict( - fn=img2img, + fn=wrap_gradio_gpu_call(modules.img2img.img2img), _js="submit_img2img", inputs=[ dummy_component, @@ -828,7 +846,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): open_extras_folder = gr.Button('Open output directory', elem_id=button_id) submit.click( - fn=run_extras, + fn=wrap_gradio_gpu_call(modules.extras.run_extras), _js="get_extras_tab_index", inputs=[ dummy_component, @@ -878,7 +896,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): pnginfo_send_to_img2img = gr.Button('Send to img2img') image.change( - fn=wrap_gradio_call(run_pnginfo), + fn=wrap_gradio_call(modules.extras.run_pnginfo), inputs=[image], outputs=[html, generation_info, html2], ) @@ -887,7 +905,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - + with gr.Row(): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") @@ -896,10 +914,96 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") save_as_half = gr.Checkbox(value=False, label="Safe as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - + with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + with gr.Blocks() as textual_inversion_interface: + with gr.Row().style(equal_height=False): + with gr.Column(): + with gr.Group(): + gr.HTML(value="

Create a new embedding

") + + new_embedding_name = gr.Textbox(label="Name") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create", variant='primary') + + with gr.Group(): + gr.HTML(value="

Train an embedding; must specify a directory with a set of 512x512 images

") + train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + learn_rate = gr.Number(label='Learning rate', value=5.0e-03) + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") + template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) + steps = gr.Number(label='Max steps', value=100000, precision=0) + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0) + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0) + + with gr.Row(): + with gr.Column(scale=2): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_training = gr.Button(value="Interrupt") + train_embedding = gr.Button(value="Train", variant='primary') + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + nvpt, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + learn_rate, + dataset_directory, + log_directory, + steps, + create_image_every, + save_embedding_every, + template_file, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + def create_setting_component(key): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default @@ -1011,6 +1115,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (textual_inversion_interface, "Textual inversion", "ti"), (settings_interface, "Settings", "settings"), ] @@ -1044,11 +1149,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): def modelmerger(*args): try: - results = run_modelmerger(*args) + results = modules.extras.run_modelmerger(*args) except Exception as e: print("Error loading/saving model file:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() #To remove the potentially missing models from the list + modules.sd_models.list_models() # to remove the potentially missing models from the list return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] return results -- cgit v1.2.3 From 0114057ad672a581bd0b598870b58b674b1a3624 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 15:49:42 +0300 Subject: fix incorrect use of glob in modelloader for #1410 --- modules/modelloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/modelloader.py b/modules/modelloader.py index 8c862b42..015aeafa 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -43,7 +43,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None for place in places: if os.path.exists(place): for file in glob.iglob(place + '**/**', recursive=True): - full_path = os.path.join(place, file) + full_path = file if os.path.isdir(full_path): continue if len(ext_filter) != 0: -- cgit v1.2.3 From 0758f6e641b5790ce566a998d43e0ea74a627766 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 17:24:50 +0300 Subject: fix --ckpt option breaking model selection --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 5b3dbdc7..9259d69e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -69,7 +69,7 @@ def list_models(): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) - shared.opts.sd_model_checkpoint = title + shared.opts.data['sd_model_checkpoint'] = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) for filename in model_list: -- cgit v1.2.3 From 88ec0cf5571883d84abd09196652b3679e359f2e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 19:40:51 +0300 Subject: fix for incorrect embedding token length calculation (will break seeds that use embeddings, you're welcome!) add option to input initialization text for embeddings --- modules/sd_hijack.py | 8 ++++---- modules/textual_inversion/textual_inversion.py | 13 +++++-------- modules/textual_inversion/ui.py | 4 ++-- modules/ui.py | 2 ++ 4 files changed, 13 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fd57e5c5..3fa06242 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) if embedding is None: remade_tokens.append(token) @@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_tokens += [0] * emb_len multipliers += [weight] * emb_len used_custom_terms.append((embedding.name, embedding.checksum())) - i += emb_len + i += embedding_length_in_tokens if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: @@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_tokens += [0] * emb_len multipliers += [mult] * emb_len used_custom_terms.append((embedding.name, embedding.checksum())) - i += emb_len + i += embedding_length_in_tokens if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c0baaace..0c50161d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -117,24 +117,21 @@ class EmbeddingDatabase: possible_matches = self.ids_lookup.get(token, None) if possible_matches is None: - return None + return None, None for ids, embedding in possible_matches: if tokens[offset:offset + len(ids)] == ids: - return embedding + return embedding, len(ids) - return None + return None, None - -def create_embedding(name, num_vectors_per_token): - init_text = '*' - +def create_embedding(name, num_vectors_per_token, init_text='*'): cond_model = shared.sd_model.cond_stage_model embedding_layer = cond_model.wrapped.transformer.text_model.embeddings ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] - embedded = embedding_layer(ids.to(devices.device)).squeeze(0) + embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) for i in range(num_vectors_per_token): diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index ce3677a9..66c43ffb 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -6,8 +6,8 @@ import modules.textual_inversion.textual_inversion as ti from modules import sd_hijack, shared -def create_embedding(name, nvpt): - filename = ti.create_embedding(name, nvpt) +def create_embedding(name, initialization_text, nvpt): + filename = ti.create_embedding(name, nvpt, init_text=initialization_text) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() diff --git a/modules/ui.py b/modules/ui.py index 3b81a4f7..eca50df0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -954,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="

Create a new embedding

") new_embedding_name = gr.Textbox(label="Name") + initialization_text = gr.Textbox(label="Initialization text", value="*") nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) with gr.Row(): @@ -997,6 +998,7 @@ def create_ui(wrap_gradio_gpu_call): fn=modules.textual_inversion.ui.create_embedding, inputs=[ new_embedding_name, + initialization_text, nvpt, ], outputs=[ -- cgit v1.2.3 From 71fe7fa49f5eb1a2c89932a9d217ed153c12fc8b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 19:56:37 +0300 Subject: fix using aaaa-100 embedding when the prompt has aaaa-10000 and you have both aaaa-100 and aaaa-10000 in the directory with embeddings. --- modules/textual_inversion/textual_inversion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 0c50161d..9d2241ce 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -57,7 +57,8 @@ class EmbeddingDatabase: first_id = ids[0] if first_id not in self.ids_lookup: self.ids_lookup[first_id] = [] - self.ids_lookup[first_id].append((ids, embedding)) + + self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True) return embedding -- cgit v1.2.3 From 4ec4af6e0b7addeee5221a03f32d117ccdc875d9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 20:15:25 +0300 Subject: add checkpoint info to saved embeddings --- modules/textual_inversion/textual_inversion.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 9d2241ce..1183aab7 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,7 +7,7 @@ import tqdm import html import datetime -from modules import shared, devices, sd_hijack, processing +from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset @@ -17,6 +17,8 @@ class Embedding: self.name = name self.step = step self.cached_checksum = None + self.sd_checkpoint = None + self.sd_checkpoint_name = None def save(self, filename): embedding_data = { @@ -24,6 +26,8 @@ class Embedding: "string_to_param": {"*": self.vec}, "name": self.name, "step": self.step, + "sd_checkpoint": self.sd_checkpoint, + "sd_checkpoint_name": self.sd_checkpoint_name, } torch.save(embedding_data, filename) @@ -41,6 +45,7 @@ class Embedding: self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' return self.cached_checksum + class EmbeddingDatabase: def __init__(self, embeddings_dir): self.ids_lookup = {} @@ -96,6 +101,8 @@ class EmbeddingDatabase: vec = emb.detach().to(devices.device, dtype=torch.float32) embedding = Embedding(vec, name) embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('hash', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) self.register_embedding(embedding, shared.sd_model) for fn in os.listdir(self.embeddings_dir): @@ -249,6 +256,10 @@ Last saved image: {html.escape(last_saved_image)}

""" + checkpoint = sd_models.select_checkpoint() + + embedding.sd_checkpoint = checkpoint.hash + embedding.sd_checkpoint_name = checkpoint.model_name embedding.cached_checksum = None embedding.save(filename) -- cgit v1.2.3 From 3ff0de2c594b786ef948a89efb1814c59bb42117 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 20:23:40 +0300 Subject: added --disable-console-progressbars to disable progressbars in console disabled printing prompts to console by default, enabled by --enable-console-prompts --- modules/img2img.py | 4 +++- modules/sd_samplers.py | 8 ++++++-- modules/shared.py | 7 +++++-- modules/txt2img.py | 4 +++- 4 files changed, 17 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 03e934e9..f4455c90 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -103,7 +103,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpaint_full_res_padding=inpaint_full_res_padding, inpainting_mask_invert=inpainting_mask_invert, ) - print(f"\nimg2img: {prompt}", file=shared.progress_print_out) + + if shared.cmd_opts.enable_console_prompts: + print(f"\nimg2img: {prompt}", file=shared.progress_print_out) p.extra_generation_params["Mask blur"] = mask_blur diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 92522214..9316875a 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -77,7 +77,9 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs): state.sampling_steps = len(sequence) state.sampling_step = 0 - for x in tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs): + seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs) + + for x in seq: if state.interrupted: break @@ -207,7 +209,9 @@ def extended_trange(sampler, count, *args, **kwargs): state.sampling_steps = count state.sampling_step = 0 - for x in tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs): + seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs) + + for x in seq: if state.interrupted: break diff --git a/modules/shared.py b/modules/shared.py index 5a591dc9..1bf7a6c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -58,6 +58,9 @@ parser.add_argument("--opt-channelslast", action='store_true', help="change memo parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) +parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) +parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) + cmd_opts = parser.parse_args() device = get_optimal_device() @@ -320,14 +323,14 @@ class TotalTQDM: ) def update(self): - if not opts.multiple_tqdm: + if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: return if self._tqdm is None: self.reset() self._tqdm.update() def updateTotal(self, new_total): - if not opts.multiple_tqdm: + if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: return if self._tqdm is None: self.reset() diff --git a/modules/txt2img.py b/modules/txt2img.py index 5368e4d0..d4406c3c 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -34,7 +34,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: denoising_strength=denoising_strength if enable_hr else None, ) - print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) + if cmd_opts.enable_console_prompts: + print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) + processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is None: -- cgit v1.2.3 From 6365a41f5981efa506dfe4e8fa878b43ca2d8d0c Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Sun, 2 Oct 2022 12:58:17 -0500 Subject: Update esrgan_model.py Use alternate ESRGAN Model download path. --- modules/esrgan_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index ea91abfe..4aed9283 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -73,8 +73,8 @@ def fix_model_layers(crt_model, pretrained_net): class UpscalerESRGAN(Upscaler): def __init__(self, dirname): self.name = "ESRGAN" - self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download" - self.model_name = "ESRGAN 4x" + self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth" + self.model_name = "ESRGAN_4x" self.scalers = [] self.user_path = dirname self.model_path = os.path.join(models_path, self.name) -- cgit v1.2.3 From a1cde7e6468f80584030525a1b07cbf0f4ee42eb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 21:09:10 +0300 Subject: disabled SD model download after multiple complaints --- modules/sd_models.py | 18 ++++++++---------- modules/textual_inversion/ui.py | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 9259d69e..9a6b568f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -13,9 +13,6 @@ from modules.paths import models_path model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -model_name = "sd-v1-4.ckpt" -model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1" -user_dir = None CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} @@ -30,12 +27,10 @@ except Exception: pass -def setup_model(dirname): - global user_dir - user_dir = dirname +def setup_model(): if not os.path.exists(model_path): os.makedirs(model_path) - checkpoints_list.clear() + list_models() @@ -45,7 +40,7 @@ def checkpoint_tiles(): def list_models(): checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name) + model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"]) def modeltitle(path, shorthash): abspath = os.path.abspath(path) @@ -106,8 +101,11 @@ def select_checkpoint(): if len(checkpoints_list) == 0: print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) - print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) - print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) + if shared.cmd_opts.ckpt is not None: + print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) + print(f" - directory {model_path}", file=sys.stderr) + if shared.cmd_opts.ckpt_dir is not None: + print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) exit(1) diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index 66c43ffb..633037d8 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -22,7 +22,7 @@ def train_embedding(*args): embedding, filename = ti.train_embedding(*args) res = f""" -Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps. +Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. Embedding saved to {html.escape(filename)} """ return res, "" -- cgit v1.2.3 From 852fd90c0dcda9cb5fbbfdf0c7308ce58034935c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 21:22:20 +0300 Subject: emergency fix for disabling SD model download after multiple complaints --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 9a6b568f..5f992064 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -45,8 +45,8 @@ def list_models(): def modeltitle(path, shorthash): abspath = os.path.abspath(path) - if user_dir is not None and abspath.startswith(user_dir): - name = abspath.replace(user_dir, '') + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): + name = abspath.replace(shared.cmd_opts.ckpt_dir, '') elif abspath.startswith(model_path): name = abspath.replace(model_path, '') else: -- cgit v1.2.3 From e808096cf641d868f88465515d70d40fc46125d4 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 2 Oct 2022 19:26:06 +0100 Subject: correct indent --- modules/scripts.py | 48 +++++++++++++++++++++++++----------------------- modules/ui.py | 25 ++++++++++++------------- 2 files changed, 37 insertions(+), 36 deletions(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 788397f5..45230f9a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -163,37 +163,39 @@ class ScriptRunner: return processed def reload_sources(self): - for si,script in list(enumerate(self.scripts)): - with open(script.filename, "r", encoding="utf8") as file: - args_from = script.args_from - args_to = script.args_to - filename = script.filename - text = file.read() + for si, script in list(enumerate(self.scripts)): + with open(script.filename, "r", encoding="utf8") as file: + args_from = script.args_from + args_to = script.args_to + filename = script.filename + text = file.read() - from types import ModuleType - compiled = compile(text, filename, 'exec') - module = ModuleType(script.filename) - exec(compiled, module.__dict__) + from types import ModuleType - for key, script_class in module.__dict__.items(): - if type(script_class) == type and issubclass(script_class, Script): - self.scripts[si] = script_class() - self.scripts[si].filename = filename - self.scripts[si].args_from = args_from - self.scripts[si].args_to = args_to + compiled = compile(text, filename, 'exec') + module = ModuleType(script.filename) + exec(compiled, module.__dict__) + + for key, script_class in module.__dict__.items(): + if type(script_class) == type and issubclass(script_class, Script): + self.scripts[si] = script_class() + self.scripts[si].filename = filename + self.scripts[si].args_from = args_from + self.scripts[si].args_to = args_to scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() def reload_script_body_only(): - scripts_txt2img.reload_sources() - scripts_img2img.reload_sources() + scripts_txt2img.reload_sources() + scripts_img2img.reload_sources() + def reload_scripts(basedir): - global scripts_txt2img,scripts_img2img + global scripts_txt2img, scripts_img2img - scripts_data.clear() - load_scripts(basedir) + scripts_data.clear() + load_scripts(basedir) - scripts_txt2img = ScriptRunner() - scripts_img2img = ScriptRunner() + scripts_txt2img = ScriptRunner() + scripts_img2img = ScriptRunner() diff --git a/modules/ui.py b/modules/ui.py index 963a2c61..6b30f84b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1003,12 +1003,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ) with gr.Row(): - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') def reload_scripts(): - modules.scripts.reload_script_body_only() + modules.scripts.reload_script_body_only() reload_script_bodies.click( fn=reload_scripts, @@ -1018,7 +1018,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ) def request_restart(): - settings_interface.gradio_ref.do_restart = True + settings_interface.gradio_ref.do_restart = True restart_gradio.click( fn=request_restart, @@ -1234,12 +1234,11 @@ for filename in sorted(os.listdir(jsdir)): if 'gradio_routes_templates_response' not in globals(): - def template_response(*args, **kwargs): - res = gradio_routes_templates_response(*args, **kwargs) - res.body = res.body.replace(b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res - - gradio_routes_templates_response = gradio.routes.templates.TemplateResponse - gradio.routes.templates.TemplateResponse = template_response - + def template_response(*args, **kwargs): + res = gradio_routes_templates_response(*args, **kwargs) + res.body = res.body.replace(b'', f'{javascript}'.encode("utf8")) + res.init_headers() + return res + + gradio_routes_templates_response = gradio.routes.templates.TemplateResponse + gradio.routes.templates.TemplateResponse = template_response -- cgit v1.2.3 From 91f327f22bb2feb780c424c74723cc0629dc34a1 Mon Sep 17 00:00:00 2001 From: Lopyter Date: Sun, 2 Oct 2022 18:15:31 +0200 Subject: make save to dirs optional for imgs saved from ui --- modules/shared.py | 1 + modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 1bf7a6c1..785e7af6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -173,6 +173,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo "grid_save_to_dirs": OptionInfo(False, "Save grids to subdirectory"), "directories_filename_pattern": OptionInfo("", "Directory name pattern"), "directories_max_prompt_words": OptionInfo(8, "Max prompt words", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), + "use_save_to_dirs_for_ui": OptionInfo(False, "Use \"Save images to a subdirectory\" option for images saved from UI"), })) options_templates.update(options_section(('upscaling', "Upscaling"), { diff --git a/modules/ui.py b/modules/ui.py index 78a15d83..8912deff 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -113,7 +113,7 @@ def save_files(js_data, images, index): p = MyObject(data) path = opts.outdir_save - save_to_dirs = opts.save_to_dirs + save_to_dirs = opts.use_save_to_dirs_for_ui if save_to_dirs: dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, p.seed, p.prompt) -- cgit v1.2.3 From c4445225f79f1c57afe52358ff4b205864eb7aac Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 21:50:14 +0300 Subject: change wording for options --- modules/shared.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 785e7af6..7246eadc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -170,10 +170,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), { options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"), - "grid_save_to_dirs": OptionInfo(False, "Save grids to subdirectory"), + "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"), + "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), "directories_filename_pattern": OptionInfo("", "Directory name pattern"), - "directories_max_prompt_words": OptionInfo(8, "Max prompt words", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), - "use_save_to_dirs_for_ui": OptionInfo(False, "Use \"Save images to a subdirectory\" option for images saved from UI"), + "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), })) options_templates.update(options_section(('upscaling', "Upscaling"), { -- cgit v1.2.3 From c7543d4940da672d970124ae8f2fec9de7bdc1da Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 22:41:21 +0300 Subject: preprocessing for textual inversion added --- modules/interrogate.py | 1 + modules/textual_inversion/preprocess.py | 75 ++++++++++++++++++++++++++ modules/textual_inversion/textual_inversion.py | 1 + modules/textual_inversion/ui.py | 14 +++-- modules/ui.py | 36 +++++++++++++ 5 files changed, 124 insertions(+), 3 deletions(-) create mode 100644 modules/textual_inversion/preprocess.py (limited to 'modules') diff --git a/modules/interrogate.py b/modules/interrogate.py index f62a4745..eed87144 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") + class InterrogateModels: blip_model = None clip_model = None diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py new file mode 100644 index 00000000..209e928f --- /dev/null +++ b/modules/textual_inversion/preprocess.py @@ -0,0 +1,75 @@ +import os +from PIL import Image, ImageOps +import tqdm + +from modules import shared, images + + +def preprocess(process_src, process_dst, process_flip, process_split, process_caption): + size = 512 + src = os.path.abspath(process_src) + dst = os.path.abspath(process_dst) + + assert src != dst, 'same directory specified as source and desitnation' + + os.makedirs(dst, exist_ok=True) + + files = os.listdir(src) + + shared.state.textinfo = "Preprocessing..." + shared.state.job_count = len(files) + + if process_caption: + shared.interrogator.load() + + def save_pic_with_caption(image, index): + if process_caption: + caption = "-" + shared.interrogator.generate_caption(image) + else: + caption = "" + + image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png")) + subindex[0] += 1 + + def save_pic(image, index): + save_pic_with_caption(image, index) + + if process_flip: + save_pic_with_caption(ImageOps.mirror(image), index) + + for index, imagefile in enumerate(tqdm.tqdm(files)): + subindex = [0] + filename = os.path.join(src, imagefile) + img = Image.open(filename).convert("RGB") + + if shared.state.interrupted: + break + + ratio = img.height / img.width + is_tall = ratio > 1.35 + is_wide = ratio < 1 / 1.35 + + if process_split and is_tall: + img = img.resize((size, size * img.height // img.width)) + + top = img.crop((0, 0, size, size)) + save_pic(top, index) + + bot = img.crop((0, img.height - size, size, img.height)) + save_pic(bot, index) + elif process_split and is_wide: + img = img.resize((size * img.width // img.height, size)) + + left = img.crop((0, 0, size, size)) + save_pic(left, index) + + right = img.crop((img.width - size, 0, img.width, size)) + save_pic(right, index) + else: + img = images.resize_image(1, img, size, size) + save_pic(img, index) + + shared.state.nextjob() + + if process_caption: + shared.interrogator.send_blip_to_ram() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 1183aab7..d4e250d8 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,6 +7,7 @@ import tqdm import html import datetime + from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index 633037d8..f19ac5e0 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -2,24 +2,31 @@ import html import gradio as gr -import modules.textual_inversion.textual_inversion as ti +import modules.textual_inversion.textual_inversion +import modules.textual_inversion.preprocess from modules import sd_hijack, shared def create_embedding(name, initialization_text, nvpt): - filename = ti.create_embedding(name, nvpt, init_text=initialization_text) + filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" +def preprocess(*args): + modules.textual_inversion.preprocess.preprocess(*args) + + return "Preprocessing finished.", "" + + def train_embedding(*args): try: sd_hijack.undo_optimizations() - embedding, filename = ti.train_embedding(*args) + embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) res = f""" Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. @@ -30,3 +37,4 @@ Embedding saved to {html.escape(filename)} raise finally: sd_hijack.apply_optimizations() + diff --git a/modules/ui.py b/modules/ui.py index 8912deff..e7bde53b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -961,6 +961,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row().style(equal_height=False): with gr.Column(): with gr.Group(): + gr.HTML(value="

See wiki for detailed explanation.

") + gr.HTML(value="

Create a new embedding

") new_embedding_name = gr.Textbox(label="Name") @@ -974,6 +976,24 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create", variant='primary') + with gr.Group(): + gr.HTML(value="

Preprocess images

") + + process_src = gr.Textbox(label='Source directory') + process_dst = gr.Textbox(label='Destination directory') + + with gr.Row(): + process_flip = gr.Checkbox(label='Flip') + process_split = gr.Checkbox(label='Split into two') + process_caption = gr.Checkbox(label='Add caption') + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + run_preprocess = gr.Button(value="Preprocess", variant='primary') + with gr.Group(): gr.HTML(value="

Train an embedding; must specify a directory with a set of 512x512 images

") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) @@ -1018,6 +1038,22 @@ def create_ui(wrap_gradio_gpu_call): ] ) + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + process_src, + process_dst, + process_flip, + process_split, + process_caption, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + train_embedding.click( fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", -- cgit v1.2.3 From 6785331e22d6a488fbf5905fab56d7fec867e038 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 22:59:01 +0300 Subject: keep textual inversion dataset latents in CPU memory to save a bit of VRAM --- modules/textual_inversion/dataset.py | 2 ++ modules/textual_inversion/textual_inversion.py | 3 +++ modules/ui.py | 4 ++-- 3 files changed, 7 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 7e134a08..e8394ff6 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -8,6 +8,7 @@ from torchvision import transforms import random import tqdm +from modules import devices class PersonalizedBase(Dataset): @@ -47,6 +48,7 @@ class PersonalizedBase(Dataset): torchdata = torch.moveaxis(torchdata, 2, 0) init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() + init_latent = init_latent.to(devices.cpu) self.dataset.append((init_latent, filename_tokens)) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d4e250d8..8686f534 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -212,7 +212,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, with torch.autocast("cuda"): c = cond_model([text]) + + x = x.to(devices.device) loss = shared.sd_model(x.unsqueeze(0), c)[0] + del x losses[embedding.step % losses.shape[0]] = loss.item() diff --git a/modules/ui.py b/modules/ui.py index e7bde53b..d9d02ece 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1002,8 +1002,8 @@ def create_ui(wrap_gradio_gpu_call): log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) steps = gr.Number(label='Max steps', value=100000, precision=0) - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0) - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0) + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) with gr.Row(): with gr.Column(scale=2): -- cgit v1.2.3 From 166283653cfe7521a422c91e8fb801f3ecb4adc8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 23:18:13 +0300 Subject: remove LDSR warning --- modules/paths.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/paths.py b/modules/paths.py index ceb80417..606f7d66 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -20,7 +20,6 @@ path_dirs = [ (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), - (os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), ] -- cgit v1.2.3 From 138662734c25dab4e73e632b7eaff9ad9c0ce2b4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 3 Oct 2022 07:57:59 +0300 Subject: use dropdown instead of radio for img2img upscaler selection --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 7246eadc..2a599e9c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -183,7 +183,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), - "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}), + "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), })) options_templates.update(options_section(('face-restoration', "Face restoration"), { -- cgit v1.2.3 From e615d4f9d101e2712c7c2d0e3e8feb19cb430c74 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 2 Oct 2022 21:08:23 +0200 Subject: Convert folder icon surrogate pair to valid utf8 --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index d9d02ece..16432151 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -69,7 +69,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️ reuse_symbol = '\u267b\ufe0f' # ♻️ art_symbol = '\U0001f3a8' # 🎨 paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\uD83D\uDCC2' +folder_symbol = '\U0001f4c2' # 📂 def plaintext_to_html(text): text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" -- cgit v1.2.3 From 34c638142eaa57f89b86545ba3c72085036398bb Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Fri, 30 Sep 2022 22:38:14 +0100 Subject: Fixed when eta = 0 Unexpected behavior when using eta = 0 in something like XY, but your default eta was set to something not 0. --- modules/sd_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 9316875a..dbf570d2 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -127,7 +127,7 @@ class VanillaStableDiffusionSampler: return res def initialize(self, p): - self.eta = p.eta or opts.eta_ddim + self.eta = p.eta if p.eta is not None else opts.eta_ddim for fieldname in ['p_sample_ddim', 'p_sample_plms']: if hasattr(self.sampler, fieldname): -- cgit v1.2.3 From 36ea4ac0f5844e5c8dec124edbdb714ccdd6013c Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Sun, 2 Oct 2022 22:21:16 -0700 Subject: moved no-style return outside join function --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index bba55158..1a046aca 100644 --- a/modules/images.py +++ b/modules/images.py @@ -315,7 +315,7 @@ def apply_filename_pattern(x, p, seed, prompt): #currently disabled if using the save button, will work otherwise # if enabled it will cause a bug because styles is not included in the save_files data dictionary if hasattr(p, "styles"): - x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"] or "None"), replace_spaces=False)) + x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False)) x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) -- cgit v1.2.3 From 6491b09c24ea77f1f69990ea80a216f9ce319589 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 3 Oct 2022 08:53:52 +0300 Subject: use existing function for gfpgan --- modules/gfpgan_model.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index bb30d733..dd3fbcab 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -97,11 +97,7 @@ def setup_model(dirname): return "GFPGAN" def restore(self, np_image): - np_image_bgr = np_image[:, :, ::-1] - cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) - np_image = gfpgan_output_bgr[:, :, ::-1] - - return np_image + return gfpgan_fix_faces(np_image) shared.face_restorers.append(FaceRestorerGFPGAN()) except Exception: -- cgit v1.2.3 From 43a74fa595003321200a40bd2431e56c245e75ed Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 3 Oct 2022 11:48:19 +0300 Subject: batch processing for img2img with an empty output directory, by request --- modules/img2img.py | 7 +++++-- modules/ui.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index f4455c90..2ff8e261 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -23,8 +23,10 @@ def process_batch(p, input_dir, output_dir, args): print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") + save_normally = output_dir == '' + p.do_not_save_grid = True - p.do_not_save_samples = True + p.do_not_save_samples = not save_normally state.job_count = len(images) * p.n_iter @@ -48,7 +50,8 @@ def process_batch(p, input_dir, output_dir, args): left, right = os.path.splitext(filename) filename = f"{left}-{n}{right}" - processed_image.save(os.path.join(output_dir, filename)) + if not save_normally: + processed_image.save(os.path.join(output_dir, filename)) def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): diff --git a/modules/ui.py b/modules/ui.py index 16432151..55f7aa95 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -658,7 +658,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.TabItem('Batch img2img', id='batch'): hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"

Process images in a directory on the same machine where the server is running.{hidden}

") + gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) -- cgit v1.2.3 From 2865ef4b9ab16d56326cc805541bebcf01d099bc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 3 Oct 2022 13:10:03 +0300 Subject: fix broken date in TI --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 8686f534..cd9f3498 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -164,7 +164,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') - log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name) + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) if save_embedding_every > 0: embedding_dir = os.path.join(log_directory, "embeddings") -- cgit v1.2.3 From 5ef0baf5eaec7f21a1666af424405cbee19f3764 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 08:52:11 +0300 Subject: add support for gelbooru tags in filenames for textual inversion --- modules/textual_inversion/dataset.py | 7 +++++-- modules/textual_inversion/preprocess.py | 4 +++- 2 files changed, 8 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index e8394ff6..7c44ea5b 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -9,6 +9,9 @@ from torchvision import transforms import random import tqdm from modules import devices +import re + +re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): @@ -38,8 +41,8 @@ class PersonalizedBase(Dataset): image = image.resize((self.width, self.height), PIL.Image.BICUBIC) filename = os.path.basename(path) - filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-') - filename_tokens = [token for token in filename_tokens if token.isalpha()] + filename_tokens = os.path.splitext(filename)[0] + filename_tokens = re_tag.findall(filename_tokens) npimage = np.array(image).astype(np.uint8) npimage = (npimage / 127.5 - 1.0).astype(np.float32) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 209e928f..f545a993 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -26,7 +26,9 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca if process_caption: caption = "-" + shared.interrogator.generate_caption(image) else: - caption = "" + caption = filename + caption = os.path.splitext(caption)[0] + caption = os.path.basename(caption) image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png")) subindex[0] += 1 -- cgit v1.2.3 From eeab7aedf532680a6ae9058ee272450bb07e41eb Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 4 Oct 2022 04:24:35 -0400 Subject: Add --use-cpu command line option Remove MPS detection to use CPU for GFPGAN / CodeFormer and add a --use-cpu command line option. --- modules/devices.py | 5 ++--- modules/esrgan_model.py | 9 ++++----- modules/scunet_model.py | 8 ++++---- modules/shared.py | 9 +++++++-- 4 files changed, 17 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 5d9c7a07..b5a0cd29 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,8 +1,8 @@ import torch -# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility from modules import errors +# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility has_mps = getattr(torch, 'has_mps', False) cpu = torch.device("cpu") @@ -32,8 +32,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") -device = get_optimal_device() -device_gfpgan = device_codeformer = cpu if device.type == 'mps' else device +device = device_gfpgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() dtype = torch.float16 def randn(seed, shape): diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 4aed9283..d17e730f 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -6,8 +6,7 @@ from PIL import Image from basicsr.utils.download_util import load_file_from_url import modules.esrgam_model_arch as arch -from modules import shared, modelloader, images -from modules.devices import has_mps +from modules import shared, modelloader, images, devices from modules.paths import models_path from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts @@ -97,7 +96,7 @@ class UpscalerESRGAN(Upscaler): model = self.load_model(selected_model) if model is None: return img - model.to(shared.device) + model.to(devices.device_esrgan) img = esrgan_upscale(model, img) return img @@ -112,7 +111,7 @@ class UpscalerESRGAN(Upscaler): print("Unable to load %s from %s" % (self.model_path, filename)) return None - pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) + pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) pretrained_net = fix_model_layers(crt_model, pretrained_net) @@ -127,7 +126,7 @@ def upscale_without_tiling(model, img): img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(shared.device) + img = img.unsqueeze(0).to(devices.device_esrgan) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/scunet_model.py b/modules/scunet_model.py index 7987ac14..fb64b740 100644 --- a/modules/scunet_model.py +++ b/modules/scunet_model.py @@ -8,7 +8,7 @@ import torch from basicsr.utils.download_util import load_file_from_url import modules.upscaler -from modules import shared, modelloader +from modules import devices, modelloader from modules.paths import models_path from modules.scunet_model_arch import SCUNet as net @@ -51,12 +51,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler): if model is None: return img - device = shared.device + device = devices.device_scunet img = np.array(img) img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(shared.device) + img = img.unsqueeze(0).to(device) img = img.to(device) with torch.no_grad(): @@ -69,7 +69,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): return PIL.Image.fromarray(output, 'RGB') def load_model(self, path: str): - device = shared.device + device = devices.device_scunet if "http" in path: filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, progress=True) diff --git a/modules/shared.py b/modules/shared.py index 2a599e9c..7899ab8d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -12,7 +12,7 @@ import modules.interrogate import modules.memmon import modules.sd_models import modules.styles -from modules.devices import get_optimal_device +import modules.devices as devices from modules.paths import script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -46,6 +46,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") +parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU for specified modules", default=[]) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) @@ -63,7 +64,11 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print cmd_opts = parser.parse_args() -device = get_optimal_device() + +devices.device, devices.device_gfpgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ +(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'ESRGAN', 'SCUNet', 'CodeFormer']) + +device = devices.device batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram -- cgit v1.2.3 From 27ddc24fdee1fbe709054a43235ab7f9c51b3e9f Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 4 Oct 2022 05:18:17 -0400 Subject: Add BSRGAN to --add-cpu --- modules/bsrgan_model.py | 6 +++--- modules/devices.py | 2 +- modules/shared.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py index e62c6657..3bd80791 100644 --- a/modules/bsrgan_model.py +++ b/modules/bsrgan_model.py @@ -8,7 +8,7 @@ import torch from basicsr.utils.download_util import load_file_from_url import modules.upscaler -from modules import shared, modelloader +from modules import devices, modelloader from modules.bsrgan_model_arch import RRDBNet from modules.paths import models_path @@ -44,13 +44,13 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler): model = self.load_model(selected_file) if model is None: return img - model.to(shared.device) + model.to(devices.device_bsrgan) torch.cuda.empty_cache() img = np.array(img) img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(shared.device) + img = img.unsqueeze(0).to(devices.device_bsrgan) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/devices.py b/modules/devices.py index b5a0cd29..b7899632 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -32,7 +32,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") -device = device_gfpgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() +device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() dtype = torch.float16 def randn(seed, shape): diff --git a/modules/shared.py b/modules/shared.py index 7899ab8d..95b98a06 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -46,7 +46,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU for specified modules", default=[]) +parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU for specified modules", default=[]) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) @@ -65,8 +65,8 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print cmd_opts = parser.parse_args() -devices.device, devices.device_gfpgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ -(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'ESRGAN', 'SCUNet', 'CodeFormer']) +devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ +(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer']) device = devices.device -- cgit v1.2.3 From dc9c5a97742e3a34d37da7108642d8adc0dc5858 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 4 Oct 2022 05:22:50 -0400 Subject: Modify --add-cpu description --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 95b98a06..25aff5b0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -46,7 +46,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU for specified modules", default=[]) +parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[]) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) -- cgit v1.2.3 From 6c6ae28bf5fd1e8bc3e8f64a3430b6f29f338f77 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 12:32:22 +0300 Subject: send all three of GFPGAN's and codeformer's models to CPU memory instead of just one for #1283 --- modules/codeformer_model.py | 12 ++++++++++-- modules/devices.py | 10 ++++++++++ modules/gfpgan_model.py | 14 ++++++++++++-- modules/processing.py | 16 +++++++++------- 4 files changed, 41 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index a29f3855..e6d9fa4f 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -69,10 +69,14 @@ def setup_model(dirname): self.net = net self.face_helper = face_helper - self.net.to(devices.device_codeformer) return net, face_helper + def send_model_to(self, device): + self.net.to(device) + self.face_helper.face_det.to(device) + self.face_helper.face_parse.to(device) + def restore(self, np_image, w=None): np_image = np_image[:, :, ::-1] @@ -82,6 +86,8 @@ def setup_model(dirname): if self.net is None or self.face_helper is None: return np_image + self.send_model_to(devices.device_codeformer) + self.face_helper.clean_all() self.face_helper.read_image(np_image) self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) @@ -113,8 +119,10 @@ def setup_model(dirname): if original_resolution != restored_img.shape[0:2]: restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) + self.face_helper.clean_all() + if shared.opts.face_restoration_unload: - self.net.to(devices.cpu) + self.send_model_to(devices.cpu) return restored_img diff --git a/modules/devices.py b/modules/devices.py index ff82f2f6..12aab665 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,3 +1,5 @@ +import contextlib + import torch # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility @@ -57,3 +59,11 @@ def randn_without_seed(shape): return torch.randn(shape, device=device) + +def autocast(): + from modules import shared + + if dtype == torch.float32 or shared.cmd_opts.precision == "full": + return contextlib.nullcontext() + + return torch.autocast("cuda") diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index dd3fbcab..5586b554 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -37,22 +37,32 @@ def gfpgann(): print("Unable to load gfpgan model!") return None model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) - model.gfpgan.to(shared.device) loaded_gfpgan_model = model return model +def send_model_to(model, device): + model.gfpgan.to(device) + model.face_helper.face_det.to(device) + model.face_helper.face_parse.to(device) + + def gfpgan_fix_faces(np_image): model = gfpgann() if model is None: return np_image + + send_model_to(model, devices.device) + np_image_bgr = np_image[:, :, ::-1] cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) np_image = gfpgan_output_bgr[:, :, ::-1] + model.face_helper.clean_all() + if shared.opts.face_restoration_unload: - model.gfpgan.to(devices.cpu) + send_model_to(model, devices.cpu) return np_image diff --git a/modules/processing.py b/modules/processing.py index 0a4b6198..9cbecdd8 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,4 +1,3 @@ -import contextlib import json import math import os @@ -330,9 +329,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] - precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext - ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope) - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + + with torch.no_grad(): p.init(all_prompts, all_seeds, all_subseeds) if state.job_count == -1: @@ -351,8 +349,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #c = p.sd_model.get_learned_conditioning(prompts) - uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) - c = prompt_parser.get_learned_conditioning(prompts, p.steps) + with devices.autocast(): + uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) + c = prompt_parser.get_learned_conditioning(prompts, p.steps) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: @@ -361,7 +360,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) + with devices.autocast(): + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength).to(devices.dtype) + if state.interrupted: # if we are interruped, sample returns just noise @@ -386,6 +387,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() x_sample = modules.face_restoration.restore_faces(x_sample) + devices.torch_gc() image = Image.fromarray(x_sample) -- cgit v1.2.3 From 2f1b61d97987ae0a52a7dfc6bc99c68928bdb594 Mon Sep 17 00:00:00 2001 From: dan Date: Mon, 3 Oct 2022 19:25:36 +0800 Subject: Allow nested structures inside schedules --- modules/prompt_parser.py | 119 +++++++++++++++++++++-------------------------- 1 file changed, 53 insertions(+), 66 deletions(-) (limited to 'modules') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index e811eb9e..99c8ed99 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,20 +1,11 @@ import re from collections import namedtuple import torch +from lark import Lark, Transformer, Visitor +import functools import modules.shared as shared -re_prompt = re.compile(r''' -(.*?) -\[ - ([^]:]+): - (?:([^]:]*):)? - ([0-9]*\.?[0-9]+) -] -| -(.+) -''', re.X) - # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] @@ -25,61 +16,57 @@ re_prompt = re.compile(r''' def get_learned_conditioning_prompt_schedules(prompts, steps): - res = [] - cache = {} - - for prompt in prompts: - prompt_schedule: list[list[str | int]] = [[steps, ""]] - - cached = cache.get(prompt, None) - if cached is not None: - res.append(cached) - continue - - for m in re_prompt.finditer(prompt): - plaintext = m.group(1) if m.group(5) is None else m.group(5) - concept_from = m.group(2) - concept_to = m.group(3) - if concept_to is None: - concept_to = concept_from - concept_from = "" - swap_position = float(m.group(4)) if m.group(4) is not None else None - - if swap_position is not None: - if swap_position < 1: - swap_position = swap_position * steps - swap_position = int(min(swap_position, steps)) - - swap_index = None - found_exact_index = False - for i in range(len(prompt_schedule)): - end_step = prompt_schedule[i][0] - prompt_schedule[i][1] += plaintext - - if swap_position is not None and swap_index is None: - if swap_position == end_step: - swap_index = i - found_exact_index = True - - if swap_position < end_step: - swap_index = i - - if swap_index is not None: - if not found_exact_index: - prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]]) - - for i in range(len(prompt_schedule)): - end_step = prompt_schedule[i][0] - must_replace = swap_position < end_step - - prompt_schedule[i][1] += concept_to if must_replace else concept_from - - res.append(prompt_schedule) - cache[prompt] = prompt_schedule - #for t in prompt_schedule: - # print(t) - - return res + grammar = r""" + start: prompt + prompt: (emphasized | scheduled | weighted | plain)* + !emphasized: "(" prompt ")" + | "(" prompt ":" prompt ")" + | "[" prompt "]" + scheduled: "[" (prompt ":")? prompt ":" NUMBER "]" + !weighted: "{" weighted_item ("|" weighted_item)* "}" + !weighted_item: prompt (":" prompt)? + plain: /([^\\\[\](){}:|]|\\.)+/ + %import common.SIGNED_NUMBER -> NUMBER + """ + parser = Lark(grammar, parser='lalr') + def collect_steps(steps, tree): + l = [steps] + class CollectSteps(Visitor): + def scheduled(self, tree): + tree.children[-1] = float(tree.children[-1]) + if tree.children[-1] < 1: + tree.children[-1] *= steps + tree.children[-1] = min(steps, int(tree.children[-1])) + l.append(tree.children[-1]) + CollectSteps().visit(tree) + return sorted(set(l)) + def at_step(step, tree): + class AtStep(Transformer): + def scheduled(self, args): + if len(args) == 2: + before, after, when = (), *args + else: + before, after, when = args + yield before if step <= when else after + def start(self, args): + def flatten(x): + if type(x) == str: + yield x + else: + for gen in x: + yield from flatten(gen) + return ''.join(flatten(args[0])) + def plain(self, args): + yield args[0].value + def __default__(self, data, children, meta): + for child in children: + yield from child + return AtStep().transform(tree) + @functools.cache + def get_schedule(prompt): + tree = parser.parse(prompt) + return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)] + return [get_schedule(prompt) for prompt in prompts] ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) -- cgit v1.2.3 From 61652461242951966e5b4cee83ce359cefa91c17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 14:23:22 +0300 Subject: support interrupting after the previous change --- modules/processing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 9cbecdd8..6f5599c7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -361,7 +361,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: shared.state.job = f"Batch {n+1} out of {p.n_iter}" with devices.autocast(): - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength).to(devices.dtype) + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) if state.interrupted: @@ -369,6 +369,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: # use the image collected previously in sampler loop samples_ddim = shared.state.current_latent + samples_ddim = samples_ddim.to(devices.dtype) + x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) -- cgit v1.2.3 From d5bba20a58f43a9f984bb67b4e17f48661f6b818 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 14:35:12 +0300 Subject: ignore errors in parse for purposes of token counting for #1564 --- modules/ui.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 55f7aa95..20dc8c37 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -386,14 +386,22 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: outputs=[seed, dummy_component] ) + def update_token_counter(text, steps): - prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps) + try: + prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps) + except Exception: + # a parsing error can happen here during typing, and we don't want to bother the user with + # messages related to it in console + prompt_schedules = [[[steps, text]]] + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) - prompts = [prompt_text for step,prompt_text in flat_prompts] + prompts = [prompt_text for step, prompt_text in flat_prompts] tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) style_class = ' class="red"' if (token_count > max_length) else "" return f"{token_count}/{max_length}" + def create_toprow(is_img2img): id_part = "img2img" if is_img2img else "txt2img" -- cgit v1.2.3 From accd00d6b8258c12b5168918a4c546b02357924a Mon Sep 17 00:00:00 2001 From: Justin Riddiough Date: Tue, 4 Oct 2022 01:14:28 -0500 Subject: Explain how to use second progress bar in pycharm --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 25aff5b0..11bdf01a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -200,7 +200,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration" options_templates.update(options_section(('system', "System"), { "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), - "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."), + "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. In PyCharm select 'emulate terminal in console output'."), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { -- cgit v1.2.3 From ea6b0d98a64290a0305e27126ea59ce1da7959a2 Mon Sep 17 00:00:00 2001 From: Justin Riddiough Date: Tue, 4 Oct 2022 06:38:45 -0500 Subject: Remove pycharm note, fix typo --- modules/shared.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 11bdf01a..a7d13b2d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -200,7 +200,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration" options_templates.update(options_section(('system', "System"), { "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), - "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. In PyCharm select 'emulate terminal in console output'."), + "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { @@ -209,7 +209,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), - "enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), + "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), -- cgit v1.2.3 From eec1b39bd54711ca31e43022d2d6ac8c6d7281da Mon Sep 17 00:00:00 2001 From: Milly Date: Tue, 4 Oct 2022 20:16:52 +0900 Subject: Apply prompt pattern last --- modules/images.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index bba55158..5b56c7e3 100644 --- a/modules/images.py +++ b/modules/images.py @@ -287,6 +287,25 @@ def apply_filename_pattern(x, p, seed, prompt): if seed is not None: x = x.replace("[seed]", str(seed)) + if p is not None: + x = x.replace("[steps]", str(p.steps)) + x = x.replace("[cfg]", str(p.cfg_scale)) + x = x.replace("[width]", str(p.width)) + x = x.replace("[height]", str(p.height)) + + #currently disabled if using the save button, will work otherwise + # if enabled it will cause a bug because styles is not included in the save_files data dictionary + if hasattr(p, "styles"): + x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"] or "None"), replace_spaces=False)) + + x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) + + x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) + x = x.replace("[date]", datetime.date.today().isoformat()) + x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S")) + x = x.replace("[job_timestamp]", shared.state.job_timestamp) + + # Apply [prompt] at last. Because it may contain any replacement word.^M if prompt is not None: x = x.replace("[prompt]", sanitize_filename_part(prompt)) if "[prompt_no_styles]" in x: @@ -295,7 +314,7 @@ def apply_filename_pattern(x, p, seed, prompt): if len(style) > 0: style_parts = [y for y in style.split("{prompt}")] for part in style_parts: - prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',') + prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',') prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip() x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False)) @@ -306,24 +325,6 @@ def apply_filename_pattern(x, p, seed, prompt): words = ["empty"] x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False)) - if p is not None: - x = x.replace("[steps]", str(p.steps)) - x = x.replace("[cfg]", str(p.cfg_scale)) - x = x.replace("[width]", str(p.width)) - x = x.replace("[height]", str(p.height)) - - #currently disabled if using the save button, will work otherwise - # if enabled it will cause a bug because styles is not included in the save_files data dictionary - if hasattr(p, "styles"): - x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"] or "None"), replace_spaces=False)) - - x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) - - x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) - x = x.replace("[date]", datetime.date.today().isoformat()) - x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S")) - x = x.replace("[job_timestamp]", shared.state.job_timestamp) - if cmd_opts.hide_ui_dir_config: x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x) -- cgit v1.2.3 From 52cef36f6ba169a8e606ecdcaed73d47378f0e8e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 16:54:31 +0300 Subject: emergency fix for img2img --- modules/processing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 6f5599c7..e9c45394 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -331,7 +331,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: output_images = [] with torch.no_grad(): - p.init(all_prompts, all_seeds, all_subseeds) + with devices.autocast(): + p.init(all_prompts, all_seeds, all_subseeds) if state.job_count == -1: state.job_count = p.n_iter -- cgit v1.2.3 From 957e29a8e9cb8ca069799ec69263e188c89ed6a6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 17:23:48 +0300 Subject: option to not show images in web ui --- modules/img2img.py | 3 +++ modules/shared.py | 1 + modules/txt2img.py | 3 +++ 3 files changed, 7 insertions(+) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 2ff8e261..da212d72 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -129,4 +129,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro if opts.samples_log_stdout: print(generation_info_js) + if opts.do_not_show_images: + processed.images = [] + return processed.images, generation_info_js, plaintext_to_html(processed.info) diff --git a/modules/shared.py b/modules/shared.py index a7d13b2d..ff4e5fa3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -229,6 +229,7 @@ options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "return_grid": OptionInfo(True, "Show grid in results for web"), + "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), diff --git a/modules/txt2img.py b/modules/txt2img.py index d4406c3c..e985242b 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -48,5 +48,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: if opts.samples_log_stdout: print(generation_info_js) + if opts.do_not_show_images: + processed.images = [] + return processed.images, generation_info_js, plaintext_to_html(processed.info) -- cgit v1.2.3 From e1b128d8e46bddb9c0b2fd3ee0eefd57e0527ee0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 17:36:39 +0300 Subject: do not touch p.seed/p.subseed during processing #1181 --- modules/processing.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index e9c45394..8180c63d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -248,9 +248,16 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see return x +def get_fixed_seed(seed): + if seed is None or seed == '' or seed == -1: + return int(random.randrange(4294967294)) + + return seed + + def fix_seed(p): - p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed - p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed + p.seed = get_fixed_seed(p.seed) + p.subseed = get_fixed_seed(p.subseed) def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): @@ -292,7 +299,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() - fix_seed(p) + seed = get_fixed_seed(p.seed) + subseed = get_fixed_seed(p.subseed) if p.outpath_samples is not None: os.makedirs(p.outpath_samples, exist_ok=True) @@ -311,15 +319,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed: else: all_prompts = p.batch_size * p.n_iter * [p.prompt] - if type(p.seed) == list: - all_seeds = p.seed + if type(seed) == list: + all_seeds = seed else: - all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] + all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] - if type(p.subseed) == list: - all_subseeds = p.subseed + if type(subseed) == list: + all_subseeds = subseed else: - all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))] + all_subseeds = [int(subseed) + x for x in range(len(all_prompts))] def infotext(iteration=0, position_in_batch=0): return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) -- cgit v1.2.3 From 1eb588cbf19924333b88beaa1ac0041904966640 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 18:02:01 +0300 Subject: remove functools.cache as some people are having issues with it --- modules/prompt_parser.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 99c8ed99..5d58c4ed 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -29,6 +29,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): %import common.SIGNED_NUMBER -> NUMBER """ parser = Lark(grammar, parser='lalr') + def collect_steps(steps, tree): l = [steps] class CollectSteps(Visitor): @@ -40,6 +41,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): l.append(tree.children[-1]) CollectSteps().visit(tree) return sorted(set(l)) + def at_step(step, tree): class AtStep(Transformer): def scheduled(self, args): @@ -62,11 +64,13 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): for child in children: yield from child return AtStep().transform(tree) - @functools.cache + def get_schedule(prompt): tree = parser.parse(prompt) return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)] - return [get_schedule(prompt) for prompt in prompts] + + promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)} + return [promptdict[prompt] for prompt in prompts] ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) -- cgit v1.2.3 From 90e911fd546e76f879b38a764473569911a0f845 Mon Sep 17 00:00:00 2001 From: Rae Fu Date: Tue, 4 Oct 2022 09:49:51 -0600 Subject: prompt_parser: allow spaces in schedules, add test, log/ignore errors Only build the parser once (at import time) instead of for each step. doctest is run by simply executing modules/prompt_parser.py --- modules/processing.py | 10 ++-- modules/prompt_parser.py | 139 ++++++++++++++++++++++++++++++----------------- 2 files changed, 95 insertions(+), 54 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 8180c63d..bb94033b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -84,7 +84,7 @@ class StableDiffusionProcessing: self.s_tmin = opts.s_tmin self.s_tmax = float('inf') # not representable as a standard ui option self.s_noise = opts.s_noise - + if not seed_enable_extras: self.subseed = -1 self.subseed_strength = 0 @@ -296,7 +296,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: assert(len(p.prompt) > 0) else: assert p.prompt is not None - + devices.torch_gc() seed = get_fixed_seed(p.seed) @@ -359,8 +359,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): - uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) - c = prompt_parser.get_learned_conditioning(prompts, p.steps) + uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) + c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: @@ -527,7 +527,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): # GC now before running the next img2img to prevent running out of memory x = None devices.torch_gc() - + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) return samples diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 5d58c4ed..a3b12421 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,10 +1,7 @@ import re from collections import namedtuple -import torch -from lark import Lark, Transformer, Visitor -import functools -import modules.shared as shared +import lark # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): @@ -14,25 +11,48 @@ import modules.shared as shared # [75, 'fantasy landscape with a lake and an oak in background masterful'] # [100, 'fantasy landscape with a lake and a christmas tree in background masterful'] +schedule_parser = lark.Lark(r""" +!start: (prompt | /[][():]/+)* +prompt: (emphasized | scheduled | plain | WHITESPACE)* +!emphasized: "(" prompt ")" + | "(" prompt ":" prompt ")" + | "[" prompt "]" +scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]" +WHITESPACE: /\s+/ +plain: /([^\\\[\]():]|\\.)+/ +%import common.SIGNED_NUMBER -> NUMBER +""") def get_learned_conditioning_prompt_schedules(prompts, steps): - grammar = r""" - start: prompt - prompt: (emphasized | scheduled | weighted | plain)* - !emphasized: "(" prompt ")" - | "(" prompt ":" prompt ")" - | "[" prompt "]" - scheduled: "[" (prompt ":")? prompt ":" NUMBER "]" - !weighted: "{" weighted_item ("|" weighted_item)* "}" - !weighted_item: prompt (":" prompt)? - plain: /([^\\\[\](){}:|]|\\.)+/ - %import common.SIGNED_NUMBER -> NUMBER """ - parser = Lark(grammar, parser='lalr') + >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0] + >>> g("test") + [[10, 'test']] + >>> g("a [b:3]") + [[3, 'a '], [10, 'a b']] + >>> g("a [b: 3]") + [[3, 'a '], [10, 'a b']] + >>> g("a [[[b]]:2]") + [[2, 'a '], [10, 'a [[b]]']] + >>> g("[(a:2):3]") + [[3, ''], [10, '(a:2)']] + >>> g("a [b : c : 1] d") + [[1, 'a b d'], [10, 'a c d']] + >>> g("a[b:[c:d:2]:1]e") + [[1, 'abe'], [2, 'ace'], [10, 'ade']] + >>> g("a [unbalanced") + [[10, 'a [unbalanced']] + >>> g("a [b:.5] c") + [[5, 'a c'], [10, 'a b c']] + >>> g("a [{b|d{:.5] c") # not handling this right now + [[5, 'a c'], [10, 'a {b|d{ c']] + >>> g("((a][:b:c [d:3]") + [[3, '((a][:b:c '], [10, '((a][:b:c d']] + """ def collect_steps(steps, tree): l = [steps] - class CollectSteps(Visitor): + class CollectSteps(lark.Visitor): def scheduled(self, tree): tree.children[-1] = float(tree.children[-1]) if tree.children[-1] < 1: @@ -43,13 +63,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): return sorted(set(l)) def at_step(step, tree): - class AtStep(Transformer): + class AtStep(lark.Transformer): def scheduled(self, args): - if len(args) == 2: - before, after, when = (), *args - else: - before, after, when = args - yield before if step <= when else after + before, after, _, when = args + yield before or () if step <= when else after def start(self, args): def flatten(x): if type(x) == str: @@ -57,16 +74,22 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): else: for gen in x: yield from flatten(gen) - return ''.join(flatten(args[0])) + return ''.join(flatten(args)) def plain(self, args): yield args[0].value def __default__(self, data, children, meta): for child in children: yield from child return AtStep().transform(tree) - + def get_schedule(prompt): - tree = parser.parse(prompt) + try: + tree = schedule_parser.parse(prompt) + except lark.exceptions.LarkError as e: + if 0: + import traceback + traceback.print_exc() + return [[steps, prompt]] return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)] promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)} @@ -77,8 +100,7 @@ ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"]) -def get_learned_conditioning(prompts, steps): - +def get_learned_conditioning(model, prompts, steps): res = [] prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) @@ -92,7 +114,7 @@ def get_learned_conditioning(prompts, steps): continue texts = [x[1] for x in prompt_schedule] - conds = shared.sd_model.get_learned_conditioning(texts) + conds = model.get_learned_conditioning(texts) cond_schedule = [] for i, (end_at_step, text) in enumerate(prompt_schedule): @@ -105,12 +127,13 @@ def get_learned_conditioning(prompts, steps): def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): - res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype) + param = c.schedules[0][0].cond + res = torch.zeros(c.shape, device=param.device, dtype=param.dtype) for i, cond_schedule in enumerate(c.schedules): target_index = 0 - for curret_index, (end_at, cond) in enumerate(cond_schedule): + for current, (end_at, cond) in enumerate(cond_schedule): if current_step <= end_at: - target_index = curret_index + target_index = current break res[i] = cond_schedule[target_index].cond @@ -148,23 +171,26 @@ def parse_prompt_attention(text): \\ - literal character '\' anything else - just text - Example: - - 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).' - - produces: - - [ - ['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1] - ] + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] """ res = [] @@ -206,4 +232,19 @@ def parse_prompt_attention(text): if len(res) == 0: res = [["", 1.0]] + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + return res + +if __name__ == "__main__": + import doctest + doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) +else: + import torch # doctest faster -- cgit v1.2.3 From b32852ef037251eb3d846af76e2965594e1ac7a5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 20:49:54 +0300 Subject: add editor to img2img --- modules/shared.py | 1 + modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index ff4e5fa3..e52c9b1d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -55,6 +55,7 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) +parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="color-sketch") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) diff --git a/modules/ui.py b/modules/ui.py index 20dc8c37..6cd6761b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -644,7 +644,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: with gr.TabItem('img2img', id='img2img'): - init_img = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil") + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool) with gr.TabItem('Inpaint', id='inpaint'): init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA") -- cgit v1.2.3 From ef40e4cd4d383a3405e03f1da3f5b5a1820a8f53 Mon Sep 17 00:00:00 2001 From: xpscyho Date: Tue, 4 Oct 2022 15:12:38 -0400 Subject: Display time taken in mins, secs when relevant Fixes #1656 --- modules/ui.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 6cd6761b..de6342a4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -196,6 +196,11 @@ def wrap_gradio_call(func, extra_outputs=None): res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] elapsed = time.perf_counter() - t + elapsed_m = int(elapsed // 60) + elapsed_s = elapsed % 60 + elapsed_text = f"{elapsed_s:.2f}s" + if (elapsed_m > 0): + elapsed_text = f"{elapsed_m}m "+elapsed_text if run_memmon: mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} @@ -210,7 +215,7 @@ def wrap_gradio_call(func, extra_outputs=None): vram_html = '' # last item is always HTML - res[-1] += f"

Time taken: {elapsed:.2f}s

{vram_html}
" + res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" shared.state.interrupted = False shared.state.job_count = 0 -- cgit v1.2.3 From 82380d9ac18614c87bebba1b4cfd4b147cc76a18 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Tue, 4 Oct 2022 22:28:50 -0300 Subject: Removing parts no longer needed to fix vram --- modules/devices.py | 3 +-- modules/processing.py | 21 ++++++++------------- 2 files changed, 9 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 6db4e57c..0158b11f 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,7 +1,6 @@ import contextlib import torch -import gc from modules import errors @@ -20,8 +19,8 @@ def get_optimal_device(): return cpu + def torch_gc(): - gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/modules/processing.py b/modules/processing.py index e7f9c85e..f666ba81 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -345,8 +345,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1: state.job_count = p.n_iter - for n in range(p.n_iter): - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + for n in range(p.n_iter): if state.interrupted: break @@ -395,22 +394,19 @@ def process_images(p: StableDiffusionProcessing) -> Processed: import modules.safety as safety x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) - for i, x_sample in enumerate(x_samples_ddim): - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + for i, x_sample in enumerate(x_samples_ddim): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) - if p.restore_faces: - with torch.no_grad(), precision_scope("cuda"), ema_scope(): + if p.restore_faces: if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") - x_sample = modules.face_restoration.restore_faces(x_sample) devices.torch_gc() - devices.torch_gc() + x_sample = modules.face_restoration.restore_faces(x_sample) + devices.torch_gc() - with torch.no_grad(), precision_scope("cuda"), ema_scope(): image = Image.fromarray(x_sample) if p.color_corrections is not None and i < len(p.color_corrections): @@ -438,13 +434,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts.append(infotext(n, i)) output_images.append(image) - del x_samples_ddim + del x_samples_ddim - devices.torch_gc() + devices.torch_gc() - state.nextjob() + state.nextjob() - with torch.no_grad(), precision_scope("cuda"), ema_scope(): p.color_corrections = None index_of_first_image = 0 -- cgit v1.2.3 From bbdbbd36eda870cf0bd49fdf28476c78919a123e Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 5 Oct 2022 04:43:05 +0100 Subject: shared.state.interrupt when restart is requested --- modules/ui.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index de6342a4..523ab25b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1210,6 +1210,7 @@ def create_ui(wrap_gradio_gpu_call): ) def request_restart(): + shared.state.interrupt() settings_interface.gradio_ref.do_restart = True restart_gradio.click( -- cgit v1.2.3 From 59a2b9e5afc27d2fda72069ca0635070535d18fe Mon Sep 17 00:00:00 2001 From: Greendayle Date: Wed, 5 Oct 2022 20:50:10 +0200 Subject: deepdanbooru interrogator --- modules/deepbooru.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++ modules/ui.py | 24 ++++++++++++++++----- 2 files changed, 79 insertions(+), 5 deletions(-) create mode 100644 modules/deepbooru.py (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py new file mode 100644 index 00000000..958b1c3d --- /dev/null +++ b/modules/deepbooru.py @@ -0,0 +1,60 @@ +import os.path +from concurrent.futures import ProcessPoolExecutor + +import numpy as np +import deepdanbooru as dd +import tensorflow as tf + + +def _load_tf_and_return_tags(pil_image, threshold): + this_folder = os.path.dirname(__file__) + model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28') + if not os.path.exists(model_path): + return "Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru" + + tags = dd.project.load_tags_from_project(model_path) + model = dd.project.load_model_from_project( + model_path, compile_model=True + ) + + width = model.input_shape[2] + height = model.input_shape[1] + image = np.array(pil_image) + image = tf.image.resize( + image, + size=(height, width), + method=tf.image.ResizeMethod.AREA, + preserve_aspect_ratio=True, + ) + image = image.numpy() # EagerTensor to np.array + image = dd.image.transform_and_pad_image(image, width, height) + image = image / 255.0 + image_shape = image.shape + image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2])) + + y = model.predict(image)[0] + + result_dict = {} + + for i, tag in enumerate(tags): + result_dict[tag] = y[i] + + + + result_tags_out = [] + result_tags_print = [] + for tag in tags: + if result_dict[tag] >= threshold: + result_tags_out.append(tag) + result_tags_print.append(f'{result_dict[tag]} {tag}') + + print('\n'.join(sorted(result_tags_print, reverse=True))) + + return ', '.join(result_tags_out) + + +def get_deepbooru_tags(pil_image, threshold=0.5): + with ProcessPoolExecutor() as executor: + f = executor.submit(_load_tf_and_return_tags, pil_image, threshold) + ret = f.result() # will rethrow any exceptions + return ret \ No newline at end of file diff --git a/modules/ui.py b/modules/ui.py index 20dc8c37..ae98219a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,6 +23,7 @@ import gradio.utils import gradio.routes from modules import sd_hijack +from modules.deepbooru import get_deepbooru_tags from modules.paths import script_path from modules.shared import opts, cmd_opts import modules.shared as shared @@ -312,6 +313,11 @@ def interrogate(image): return gr_show(True) if prompt is None else prompt +def interrogate_deepbooru(image): + prompt = get_deepbooru_tags(image) + return gr_show(True) if prompt is None else prompt + + def create_seed_inputs(): with gr.Row(): with gr.Box(): @@ -439,15 +445,17 @@ def create_toprow(is_img2img): outputs=[], ) - with gr.Row(): + with gr.Row(scale=1): if is_img2img: - interrogate = gr.Button('Interrogate', elem_id="interrogate") + interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") else: interrogate = None + deepbooru = None prompt_style_apply = gr.Button('Apply style', elem_id="style_apply") save_style = gr.Button('Create style', elem_id="style_create") - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button def setup_progressbar(progressbar, preview, id_part, textinfo=None): @@ -476,7 +484,7 @@ def create_ui(wrap_gradio_gpu_call): import modules.txt2img with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False) + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) with gr.Row(elem_id='txt2img_progress_row'): @@ -628,7 +636,7 @@ def create_ui(wrap_gradio_gpu_call): token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) + img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): with gr.Column(scale=1): @@ -785,6 +793,12 @@ def create_ui(wrap_gradio_gpu_call): outputs=[img2img_prompt], ) + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) + save.click( fn=wrap_gradio_call(save_files), _js="(x, y, z) => [x, y, selected_gallery_index()]", -- cgit v1.2.3 From 1506fab29ad54beb9f52236912abc432209c8089 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Wed, 5 Oct 2022 21:15:08 +0200 Subject: removing problematic tag --- modules/deepbooru.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 958b1c3d..841cb9c5 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -38,13 +38,12 @@ def _load_tf_and_return_tags(pil_image, threshold): for i, tag in enumerate(tags): result_dict[tag] = y[i] - - - result_tags_out = [] result_tags_print = [] for tag in tags: if result_dict[tag] >= threshold: + if tag.startswith("rating:"): + continue result_tags_out.append(tag) result_tags_print.append(f'{result_dict[tag]} {tag}') -- cgit v1.2.3 From 17a99baf0c929e5df4dfc4b2a96aa3890a141112 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Wed, 5 Oct 2022 22:05:24 +0200 Subject: better model search --- modules/deepbooru.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 841cb9c5..a64fd9cd 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -9,8 +9,15 @@ import tensorflow as tf def _load_tf_and_return_tags(pil_image, threshold): this_folder = os.path.dirname(__file__) model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28') - if not os.path.exists(model_path): - return "Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru" + + model_good = False + for path_candidate in [model_path, os.path.dirname(model_path)]: + if os.path.exists(os.path.join(path_candidate, 'project.json')): + model_path = path_candidate + model_good = True + if not model_good: + return ("Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/" + "deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru") tags = dd.project.load_tags_from_project(model_path) model = dd.project.load_model_from_project( -- cgit v1.2.3 From c26732fbee2a57e621ac22bf70decf7496daa4cd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 5 Oct 2022 23:16:27 +0300 Subject: added support for AND from https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/ --- modules/processing.py | 2 +- modules/prompt_parser.py | 114 ++++++++++++++++++++++++++++++++++++++++++++--- modules/sd_samplers.py | 35 ++++++++++----- modules/ui.py | 6 ++- 4 files changed, 138 insertions(+), 19 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index bb94033b..d8c6b8d5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -360,7 +360,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) - c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps) + c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index a3b12421..f7420daf 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -97,10 +97,26 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) -ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"]) def get_learned_conditioning(model, prompts, steps): + """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), + and the sampling step at which this condition is to be replaced by the next one. + + Input: + (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20) + + Output: + [ + [ + ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0')) + ], + [ + ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')), + ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0')) + ] + ] + """ res = [] prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) @@ -123,13 +139,75 @@ def get_learned_conditioning(model, prompts, steps): cache[prompt] = cond_schedule res.append(cond_schedule) - return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res) + return res + + +re_AND = re.compile(r"\bAND\b") +re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$") + + +def get_multicond_prompt_list(prompts): + res_indexes = [] + + prompt_flat_list = [] + prompt_indexes = {} + + for prompt in prompts: + subprompts = re_AND.split(prompt) + + indexes = [] + for subprompt in subprompts: + text, weight = re_weight.search(subprompt).groups() + + weight = float(weight) if weight is not None else 1.0 + + index = prompt_indexes.get(text, None) + if index is None: + index = len(prompt_flat_list) + prompt_flat_list.append(text) + prompt_indexes[text] = index + + indexes.append((index, weight)) + + res_indexes.append(indexes) + + return res_indexes, prompt_flat_list, prompt_indexes + + +class ComposableScheduledPromptConditioning: + def __init__(self, schedules, weight=1.0): + self.schedules: list[ScheduledPromptConditioning] = schedules + self.weight: float = weight + + +class MulticondLearnedConditioning: + def __init__(self, shape, batch): + self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS + self.batch: list[list[ComposableScheduledPromptConditioning]] = batch -def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): - param = c.schedules[0][0].cond - res = torch.zeros(c.shape, device=param.device, dtype=param.dtype) - for i, cond_schedule in enumerate(c.schedules): +def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: + """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. + For each prompt, the list is obtained by splitting the prompt using the AND separator. + + https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/ + """ + + res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts) + + learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps) + + res = [] + for indexes in res_indexes: + res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes]) + + return MulticondLearnedConditioning(shape=(len(prompts),), batch=res) + + +def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step): + param = c[0][0].cond + res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) + for i, cond_schedule in enumerate(c): target_index = 0 for current, (end_at, cond) in enumerate(cond_schedule): if current_step <= end_at: @@ -140,6 +218,30 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): return res +def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): + param = c.batch[0][0].schedules[0].cond + + tensors = [] + conds_list = [] + + for batch_no, composable_prompts in enumerate(c.batch): + conds_for_batch = [] + + for cond_index, composable_prompt in enumerate(composable_prompts): + target_index = 0 + for current, (end_at, cond) in enumerate(composable_prompt.schedules): + if current_step <= end_at: + target_index = current + break + + conds_for_batch.append((len(tensors), composable_prompt.weight)) + tensors.append(composable_prompt.schedules[target_index].cond) + + conds_list.append(conds_for_batch) + + return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype) + + re_attention = re.compile(r""" \\\(| \\\)| diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index dbf570d2..d27c547b 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -109,9 +109,12 @@ class VanillaStableDiffusionSampler: return 0 def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): - cond = prompt_parser.reconstruct_cond_batch(cond, self.step) + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) + assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' + cond = tensor + if self.mask is not None: img_orig = self.sampler.model.q_sample(self.init_latent, ts) x_dec = img_orig * self.mask + self.nmask * x_dec @@ -183,19 +186,31 @@ class CFGDenoiser(torch.nn.Module): self.step = 0 def forward(self, x, sigma, uncond, cond, cond_scale): - cond = prompt_parser.reconstruct_cond_batch(cond, self.step) + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + batch_size = len(conds_list) + repeats = [len(conds_list[i]) for i in range(batch_size)] + + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + cond_in = torch.cat([tensor, uncond]) + if shared.batch_cond_uncond: - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - denoised = uncond + (cond - uncond) * cond_scale + x_out = self.inner_model(x_in, sigma_in, cond=cond_in) else: - uncond = self.inner_model(x, sigma, cond=uncond) - cond = self.inner_model(x, sigma, cond=cond) - denoised = uncond + (cond - uncond) * cond_scale + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b]) + + denoised_uncond = x_out[-batch_size:] + denoised = torch.clone(denoised_uncond) + + for i, conds in enumerate(conds_list): + for cond_index, weight in conds: + denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) if self.mask is not None: denoised = self.init_latent * self.mask + self.nmask * denoised diff --git a/modules/ui.py b/modules/ui.py index 523ab25b..9620350f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -34,7 +34,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste -from modules.prompt_parser import get_learned_conditioning_prompt_schedules +from modules import prompt_parser from modules.images import apply_filename_pattern, get_next_sequence_number import modules.textual_inversion.ui @@ -394,7 +394,9 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def update_token_counter(text, steps): try: - prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps) + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) + except Exception: # a parsing error can happen here during typing, and we don't want to bother the user with # messages related to it in console -- cgit v1.2.3 From 4320f386d9641c7c234589c4cb0c0c6cbeb156ad Mon Sep 17 00:00:00 2001 From: Greendayle Date: Wed, 5 Oct 2022 22:39:32 +0200 Subject: removing underscores and colons --- modules/deepbooru.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index a64fd9cd..fb5018a6 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -56,7 +56,7 @@ def _load_tf_and_return_tags(pil_image, threshold): print('\n'.join(sorted(result_tags_print, reverse=True))) - return ', '.join(result_tags_out) + return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') def get_deepbooru_tags(pil_image, threshold=0.5): -- cgit v1.2.3 From f8e41a96bb30a04dd5e294c7e1178c1c3b09d481 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 5 Oct 2022 23:52:05 +0300 Subject: fix various float parsing errors --- modules/prompt_parser.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index f7420daf..800b12c7 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -143,8 +143,7 @@ def get_learned_conditioning(model, prompts, steps): re_AND = re.compile(r"\bAND\b") -re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$") - +re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$") def get_multicond_prompt_list(prompts): res_indexes = [] -- cgit v1.2.3 From 20f8ec877a99ce2ebf193cb1e2e773cfc77b7c41 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 00:09:32 +0300 Subject: remove type annotations in new code because presumably they don't work in 3.7 --- modules/prompt_parser.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 800b12c7..ee4c5d02 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -175,14 +175,14 @@ def get_multicond_prompt_list(prompts): class ComposableScheduledPromptConditioning: def __init__(self, schedules, weight=1.0): - self.schedules: list[ScheduledPromptConditioning] = schedules + self.schedules = schedules # : list[ScheduledPromptConditioning] self.weight: float = weight class MulticondLearnedConditioning: def __init__(self, shape, batch): self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS - self.batch: list[list[ComposableScheduledPromptConditioning]] = batch + self.batch = batch # : list[list[ComposableScheduledPromptConditioning]] def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: @@ -203,7 +203,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne return MulticondLearnedConditioning(shape=(len(prompts),), batch=res) -def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step): +def reconstruct_cond_batch(c, current_step): # c: list[list[ScheduledPromptConditioning]] param = c[0][0].cond res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for i, cond_schedule in enumerate(c): -- cgit v1.2.3 From 34c358d10d52817f7a889ae4c52096ee654f3fe6 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 5 Oct 2022 22:11:30 +0100 Subject: use typing.list in prompt_parser.py for wider python version support --- modules/prompt_parser.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 800b12c7..fdfa21ae 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,6 +1,6 @@ import re from collections import namedtuple - +from typing import List import lark # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" @@ -175,14 +175,14 @@ def get_multicond_prompt_list(prompts): class ComposableScheduledPromptConditioning: def __init__(self, schedules, weight=1.0): - self.schedules: list[ScheduledPromptConditioning] = schedules + self.schedules: List[ScheduledPromptConditioning] = schedules self.weight: float = weight class MulticondLearnedConditioning: def __init__(self, shape, batch): self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS - self.batch: list[list[ComposableScheduledPromptConditioning]] = batch + self.batch: List[List[ComposableScheduledPromptConditioning]] = batch def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: @@ -203,7 +203,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne return MulticondLearnedConditioning(shape=(len(prompts),), batch=res) -def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step): +def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step): param = c[0][0].cond res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for i, cond_schedule in enumerate(c): -- cgit v1.2.3 From 55400c981b7c1389482057a35ed6ea11f08da194 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 6 Oct 2022 03:11:15 +0100 Subject: Set gradio-img2img-tool default to 'editor' --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index e52c9b1d..bab0fe6e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -55,7 +55,7 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) -parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="color-sketch") +parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) -- cgit v1.2.3 From 2499fb4e1910d31ff12c24110f161b20641b8835 Mon Sep 17 00:00:00 2001 From: Raphael Stoeckli Date: Wed, 5 Oct 2022 21:57:18 +0200 Subject: Add sanitizer for captions in Textual inversion --- modules/textual_inversion/preprocess.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index f545a993..4f3df4bd 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,5 +1,8 @@ +from cmath import log import os from PIL import Image, ImageOps +import platform +import sys import tqdm from modules import shared, images @@ -25,6 +28,7 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca def save_pic_with_caption(image, index): if process_caption: caption = "-" + shared.interrogator.generate_caption(image) + caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png") else: caption = filename caption = os.path.splitext(caption)[0] @@ -75,3 +79,27 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca if process_caption: shared.interrogator.send_blip_to_ram() + +def sanitize_caption(base_path, original_caption, suffix): + operating_system = platform.system().lower() + if (operating_system == "windows"): + invalid_path_characters = "\\/:*?\"<>|" + max_path_length = 259 + else: + invalid_path_characters = "/" #linux/macos + max_path_length = 1023 + caption = original_caption + for invalid_character in invalid_path_characters: + caption = caption.replace(invalid_character, "") + fixed_path_length = len(base_path) + len(suffix) + if fixed_path_length + len(caption) <= max_path_length: + return caption + caption_tokens = caption.split() + new_caption = "" + for token in caption_tokens: + last_caption = new_caption + new_caption = new_caption + token + " " + if (len(new_caption) + fixed_path_length - 1 > max_path_length): + break + print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr) + return last_caption.strip() -- cgit v1.2.3 From 4288e53fc2ea25fa49715bf5b7f14603553c9e38 Mon Sep 17 00:00:00 2001 From: Raphael Stoeckli Date: Wed, 5 Oct 2022 23:11:32 +0200 Subject: removed unused import, fixed typo --- modules/textual_inversion/preprocess.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 4f3df4bd..f1c002a2 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,4 +1,3 @@ -from cmath import log import os from PIL import Image, ImageOps import platform @@ -13,7 +12,7 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) - assert src != dst, 'same directory specified as source and desitnation' + assert src != dst, 'same directory specified as source and destination' os.makedirs(dst, exist_ok=True) -- cgit v1.2.3 From 5f24b7bcf4a074fbdec757617fcd1bc82e76551b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 12:08:48 +0300 Subject: option to let users select which samplers they want to hide --- modules/processing.py | 13 ++++++------- modules/sd_samplers.py | 19 +++++++++++++++++-- modules/shared.py | 15 +++++++++------ 3 files changed, 32 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index d8c6b8d5..e01c8b3f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -11,9 +11,8 @@ import cv2 from skimage import exposure import modules.sd_hijack -from modules import devices, prompt_parser, masking +from modules import devices, prompt_parser, masking, sd_samplers from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.face_restoration @@ -110,7 +109,7 @@ class Processed: self.width = p.width self.height = p.height self.sampler_index = p.sampler_index - self.sampler = samplers[p.sampler_index].name + self.sampler = sd_samplers.samplers[p.sampler_index].name self.cfg_scale = p.cfg_scale self.steps = p.steps self.batch_size = p.batch_size @@ -265,7 +264,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params = { "Steps": p.steps, - "Sampler": samplers[p.sampler_index].name, + "Sampler": sd_samplers.samplers[p.sampler_index].name, "CFG scale": p.cfg_scale, "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), @@ -478,7 +477,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.firstphase_height_truncated = int(scale * self.height) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - self.sampler = samplers[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model) if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) @@ -521,7 +520,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - self.sampler = samplers[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) # GC now before running the next img2img to prevent running out of memory @@ -556,7 +555,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.nmask = None def init(self, all_prompts, all_seeds, all_subseeds): - self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model) crop_region = None if self.image_mask is not None: diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index d27c547b..2e1f7715 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -32,12 +32,27 @@ samplers_data_k_diffusion = [ if hasattr(k_diffusion.sampling, funcname) ] -samplers = [ +all_samplers = [ *samplers_data_k_diffusion, SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []), SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []), ] -samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']] + +samplers = [] +samplers_for_img2img = [] + + +def set_samplers(): + global samplers, samplers_for_img2img + + hidden = set(opts.hide_samplers) + hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive']) + + samplers = [x for x in all_samplers if x.name not in hidden] + samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] + + +set_samplers() sampler_extra_params = { 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], diff --git a/modules/shared.py b/modules/shared.py index bab0fe6e..ca2e4c74 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,6 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices +from modules import sd_samplers from modules.paths import script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -238,14 +239,16 @@ options_templates.update(options_section(('ui', "User interface"), { })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { - "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), - 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}), + "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), + 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) + class Options: data = None data_labels = options_templates -- cgit v1.2.3 From 2d3ea42a2d1e909bbccdb6b49561b187c60a9402 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 13:21:12 +0300 Subject: workaround for a mysterious bug where prompt weights can't be matched --- modules/prompt_parser.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index a7a6aa31..f00256f2 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -156,7 +156,9 @@ def get_multicond_prompt_list(prompts): indexes = [] for subprompt in subprompts: - text, weight = re_weight.search(subprompt).groups() + match = re_weight.search(subprompt) + + text, weight = match.groups() if match is not None else (subprompt, 1.0) weight = float(weight) if weight is not None else 1.0 -- cgit v1.2.3 From 71901b3d3bea1d035bf4a7229d19356b4b062151 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Wed, 5 Oct 2022 14:30:57 +0300 Subject: add karras scheduling variants --- modules/sd_samplers.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 2e1f7715..8d6eb762 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -26,6 +26,17 @@ samplers_k_diffusion = [ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']), ] +if opts.show_karras_scheduler_variants: + k_diffusion.sampling.sample_dpm_2_ka = k_diffusion.sampling.sample_dpm_2 + k_diffusion.sampling.sample_dpm_2_ancestral_ka = k_diffusion.sampling.sample_dpm_2_ancestral + k_diffusion.sampling.sample_lms_ka = k_diffusion.sampling.sample_lms + samplers_k_diffusion_ka = [ + ('LMS K Scheduling', 'sample_lms_ka', ['k_lms_ka']), + ('DPM2 K Scheduling', 'sample_dpm_2_ka', ['k_dpm_2_ka']), + ('DPM2 a K Scheduling', 'sample_dpm_2_ancestral_ka', ['k_dpm_2_a_ka']), + ] + samplers_k_diffusion.extend(samplers_k_diffusion_ka) + samplers_data_k_diffusion = [ SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases) for label, funcname, aliases in samplers_k_diffusion @@ -345,6 +356,8 @@ class KDiffusionSampler: if p.sampler_noise_scheduler_override: sigmas = p.sampler_noise_scheduler_override(steps) + elif self.funcname.endswith('ka'): + sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) else: sigmas = self.model_wrap.get_sigmas(steps) x = x * sigmas[0] -- cgit v1.2.3 From 3ddf80a9db8793188e2fe9488233d2b272cceb33 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Wed, 5 Oct 2022 14:31:51 +0300 Subject: add variant setting --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index ca2e4c74..9e4860a2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -236,6 +236,7 @@ options_templates.update(options_section(('ui', "User interface"), { "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), + "show_karras_scheduler_variants": OptionInfo(True, "Show Karras scheduling variants for select samplers. Try these variants if your K sampled images suffer from excessive noise."), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -- cgit v1.2.3 From 5993df24a1026225cb8af89237547c1d9101ce69 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 14:12:52 +0300 Subject: integrate the new samplers PR --- modules/processing.py | 7 +++--- modules/sd_samplers.py | 59 ++++++++++++++++++++++++++------------------------ modules/shared.py | 1 - 3 files changed, 35 insertions(+), 32 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index e01c8b3f..e567956c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -477,7 +477,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.firstphase_height_truncated = int(scale * self.height) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) @@ -520,7 +520,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) # GC now before running the next img2img to prevent running out of memory @@ -555,7 +556,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.nmask = None def init(self, all_prompts, all_seeds, all_subseeds): - self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) crop_region = None if self.image_mask is not None: diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 8d6eb762..497df943 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -13,46 +13,46 @@ from modules.shared import opts, cmd_opts, state import modules.shared as shared -SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases']) +SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) samplers_k_diffusion = [ - ('Euler a', 'sample_euler_ancestral', ['k_euler_a']), - ('Euler', 'sample_euler', ['k_euler']), - ('LMS', 'sample_lms', ['k_lms']), - ('Heun', 'sample_heun', ['k_heun']), - ('DPM2', 'sample_dpm_2', ['k_dpm_2']), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']), - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']), - ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']), + ('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}), + ('Euler', 'sample_euler', ['k_euler'], {}), + ('LMS', 'sample_lms', ['k_lms'], {}), + ('Heun', 'sample_heun', ['k_heun'], {}), + ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}), + ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}), + ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), + ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), + ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), + ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}), + ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}), ] -if opts.show_karras_scheduler_variants: - k_diffusion.sampling.sample_dpm_2_ka = k_diffusion.sampling.sample_dpm_2 - k_diffusion.sampling.sample_dpm_2_ancestral_ka = k_diffusion.sampling.sample_dpm_2_ancestral - k_diffusion.sampling.sample_lms_ka = k_diffusion.sampling.sample_lms - samplers_k_diffusion_ka = [ - ('LMS K Scheduling', 'sample_lms_ka', ['k_lms_ka']), - ('DPM2 K Scheduling', 'sample_dpm_2_ka', ['k_dpm_2_ka']), - ('DPM2 a K Scheduling', 'sample_dpm_2_ancestral_ka', ['k_dpm_2_a_ka']), - ] - samplers_k_diffusion.extend(samplers_k_diffusion_ka) - samplers_data_k_diffusion = [ - SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases) - for label, funcname, aliases in samplers_k_diffusion + SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_k_diffusion if hasattr(k_diffusion.sampling, funcname) ] all_samplers = [ *samplers_data_k_diffusion, - SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []), - SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []), + SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), + SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), ] samplers = [] samplers_for_img2img = [] +def create_sampler_with_index(list_of_configs, index, model): + config = list_of_configs[index] + sampler = config.constructor(model) + sampler.config = config + + return sampler + + def set_samplers(): global samplers, samplers_for_img2img @@ -130,6 +130,7 @@ class VanillaStableDiffusionSampler: self.step = 0 self.eta = None self.default_eta = 0.0 + self.config = None def number_of_needed_noises(self, p): return 0 @@ -291,6 +292,7 @@ class KDiffusionSampler: self.stop_at = None self.eta = None self.default_eta = 1.0 + self.config = None def callback_state(self, d): store_latent(d["denoised"]) @@ -355,11 +357,12 @@ class KDiffusionSampler: steps = steps or p.steps if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) - elif self.funcname.endswith('ka'): - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) + sigmas = p.sampler_noise_scheduler_override(steps) + elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': + sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) else: - sigmas = self.model_wrap.get_sigmas(steps) + sigmas = self.model_wrap.get_sigmas(steps) + x = x * sigmas[0] extra_params_kwargs = self.initialize(p) diff --git a/modules/shared.py b/modules/shared.py index 9e4860a2..ca2e4c74 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -236,7 +236,6 @@ options_templates.update(options_section(('ui', "User interface"), { "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), - "show_karras_scheduler_variants": OptionInfo(True, "Show Karras scheduling variants for select samplers. Try these variants if your K sampled images suffer from excessive noise."), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -- cgit v1.2.3 From be71115b1a1201d04f0e2a11e718fb31cbd26474 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 6 Oct 2022 01:09:44 +0100 Subject: Update shared.py --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index ca2e4c74..9f7c6efe 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -236,6 +236,7 @@ options_templates.update(options_section(('ui', "User interface"), { "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), + "show_progress_in_title": OptionInfo(False, "Show generation progress in window title."), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -- cgit v1.2.3 From fec71e4de24b65b0f205a3c071b71651bbcb0dfc Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 6 Oct 2022 01:35:07 +0100 Subject: Default window title progress updates on --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 9f7c6efe..5c16f025 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -236,7 +236,7 @@ options_templates.update(options_section(('ui', "User interface"), { "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), - "show_progress_in_title": OptionInfo(False, "Show generation progress in window title."), + "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -- cgit v1.2.3 From 0bb458f0ca06a7be27cf1a1003c536d1f06a5bd3 Mon Sep 17 00:00:00 2001 From: Milly Date: Wed, 5 Oct 2022 01:19:50 +0900 Subject: Removed duplicate image saving codes Use `modules.images.save_image()` instead. --- modules/images.py | 7 ++++--- modules/ui.py | 46 ++++++++++------------------------------------ 2 files changed, 14 insertions(+), 39 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index c2fadab9..810f1446 100644 --- a/modules/images.py +++ b/modules/images.py @@ -353,7 +353,7 @@ def get_next_sequence_number(path, basename): return result + 1 -def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""): +def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None): if short_filename or prompt is None or seed is None: file_decoration = "" elif opts.save_to_dirs: @@ -377,7 +377,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i else: pnginfo = None - save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) + if save_to_dirs is None: + save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) if save_to_dirs: dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /') @@ -431,4 +432,4 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file: file.write(info + "\n") - + return fullfn diff --git a/modules/ui.py b/modules/ui.py index 9620350f..4f18126f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -35,7 +35,7 @@ import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste from modules import prompt_parser -from modules.images import apply_filename_pattern, get_next_sequence_number +from modules.images import save_image import modules.textual_inversion.ui # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI @@ -114,20 +114,13 @@ def save_files(js_data, images, index): p = MyObject(data) path = opts.outdir_save save_to_dirs = opts.use_save_to_dirs_for_ui - - if save_to_dirs: - dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, p.seed, p.prompt) - path = os.path.join(opts.outdir_save, dirname) - - os.makedirs(path, exist_ok=True) - + extension: str = opts.samples_format + start_index = 0 if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only images = [images[index]] - infotexts = [data["infotexts"][index]] - else: - infotexts = data["infotexts"] + start_index = index with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: at_start = file.tell() == 0 @@ -135,37 +128,18 @@ def save_files(js_data, images, index): if at_start: writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]" - if file_decoration != "": - file_decoration = "-" + file_decoration.lower() - file_decoration = apply_filename_pattern(file_decoration, p, p.seed, p.prompt) - truncated = (file_decoration[:240] + '..') if len(file_decoration) > 240 else file_decoration - filename_base = truncated - extension = opts.samples_format.lower() - - basecount = get_next_sequence_number(path, "") - for i, filedata in enumerate(images): - file_number = f"{basecount+i:05}" - filename = file_number + filename_base + f".{extension}" - filepath = os.path.join(path, filename) - - + for image_index, filedata in enumerate(images, start_index): if filedata.startswith("data:image/png;base64,"): filedata = filedata[len("data:image/png;base64,"):] image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8')))) - if opts.enable_pnginfo and extension == 'png': - pnginfo = PngImagePlugin.PngInfo() - pnginfo.add_text('parameters', infotexts[i]) - image.save(filepath, pnginfo=pnginfo) - else: - image.save(filepath, quality=opts.jpeg_quality) - if opts.enable_pnginfo and extension in ("jpg", "jpeg", "webp"): - piexif.insert(piexif.dump({"Exif": { - piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(infotexts[i], encoding="unicode") - }}), filepath) + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + filename = os.path.relpath(fullfn, path) filenames.append(filename) writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) -- cgit v1.2.3 From dbc8a4d35129b08eab30776bbbaf3a2e7ac10a6c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 20:27:50 +0300 Subject: add generation parameters to images shown in web ui --- modules/processing.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index de818d5b..8faf9095 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -430,7 +430,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if opts.samples_save and not p.do_not_save_samples: images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) - infotexts.append(infotext(n, i)) + text = infotext(n, i) + infotexts.append(text) + image.info["parameters"] = text output_images.append(image) del x_samples_ddim @@ -447,7 +449,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: grid = images.image_grid(output_images, p.batch_size) if opts.return_grid: - infotexts.insert(0, infotext()) + text = infotext() + infotexts.insert(0, text) + grid.info["parameters"] = text output_images.insert(0, grid) index_of_first_image = 1 -- cgit v1.2.3 From cf7c784fcc0c84a8a4edd8d3aca4dda4c7025c43 Mon Sep 17 00:00:00 2001 From: Milly Date: Fri, 7 Oct 2022 00:19:52 +0900 Subject: Removed duplicate defined models_path Use `modules.paths.models_path` instead `modules.shared.model_path`. --- modules/shared.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 5c16f025..25bb6e6c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,11 +14,10 @@ import modules.sd_models import modules.styles import modules.devices as devices from modules import sd_samplers -from modules.paths import script_path, sd_path +from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file -model_path = os.path.join(script_path, 'models') parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) @@ -36,14 +35,14 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") -parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer')) -parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN')) -parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN')) -parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN')) -parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN')) -parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(model_path, 'ScuNET')) -parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR')) -parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR')) +parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) +parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) +parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) +parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN')) +parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN')) +parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET')) +parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) +parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -- cgit v1.2.3 From 070b7d60cf5dac6387b3bfc8f3b3977b620e4fd5 Mon Sep 17 00:00:00 2001 From: Milly Date: Wed, 5 Oct 2022 02:13:09 +0900 Subject: Added styles to Processed So `[styles]` pattern can use in saving image UI. --- modules/images.py | 7 +------ modules/processing.py | 2 ++ 2 files changed, 3 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 810f1446..fa0714fd 100644 --- a/modules/images.py +++ b/modules/images.py @@ -292,12 +292,7 @@ def apply_filename_pattern(x, p, seed, prompt): x = x.replace("[cfg]", str(p.cfg_scale)) x = x.replace("[width]", str(p.width)) x = x.replace("[height]", str(p.height)) - - #currently disabled if using the save button, will work otherwise - # if enabled it will cause a bug because styles is not included in the save_files data dictionary - if hasattr(p, "styles"): - x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False)) - + x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False)) x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) diff --git a/modules/processing.py b/modules/processing.py index 8faf9095..706dbfa8 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -121,6 +121,7 @@ class Processed: self.denoising_strength = getattr(p, 'denoising_strength', None) self.extra_generation_params = p.extra_generation_params self.index_of_first_image = index_of_first_image + self.styles = p.styles self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -165,6 +166,7 @@ class Processed: "extra_generation_params": self.extra_generation_params, "index_of_first_image": self.index_of_first_image, "infotexts": self.infotexts, + "styles": self.styles, } return json.dumps(obj) -- cgit v1.2.3 From 1cc36d170ac15e7f04208df32db27af1b10c867c Mon Sep 17 00:00:00 2001 From: Milly Date: Wed, 5 Oct 2022 02:17:15 +0900 Subject: Added job_timestamp to Processed So `[job_timestamp]` pattern can use in saving image UI. --- modules/images.py | 2 +- modules/processing.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index fa0714fd..669d76af 100644 --- a/modules/images.py +++ b/modules/images.py @@ -298,7 +298,7 @@ def apply_filename_pattern(x, p, seed, prompt): x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) x = x.replace("[date]", datetime.date.today().isoformat()) x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S")) - x = x.replace("[job_timestamp]", shared.state.job_timestamp) + x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp)) # Apply [prompt] at last. Because it may contain any replacement word.^M if prompt is not None: diff --git a/modules/processing.py b/modules/processing.py index 706dbfa8..f773a30e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -122,6 +122,7 @@ class Processed: self.extra_generation_params = p.extra_generation_params self.index_of_first_image = index_of_first_image self.styles = p.styles + self.job_timestamp = state.job_timestamp self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -167,6 +168,7 @@ class Processed: "index_of_first_image": self.index_of_first_image, "infotexts": self.infotexts, "styles": self.styles, + "job_timestamp": self.job_timestamp, } return json.dumps(obj) -- cgit v1.2.3 From 405c8171d1acbb994084d98770bbcb97d01d9406 Mon Sep 17 00:00:00 2001 From: Milly Date: Thu, 6 Oct 2022 00:59:04 +0900 Subject: Prefer using `Processed.sd_model_hash` attribute when filename pattern --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 669d76af..29c5ee24 100644 --- a/modules/images.py +++ b/modules/images.py @@ -295,7 +295,7 @@ def apply_filename_pattern(x, p, seed, prompt): x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False)) x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) - x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) + x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash)) x = x.replace("[date]", datetime.date.today().isoformat()) x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S")) x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp)) -- cgit v1.2.3 From b34b25b4c941819d34f29be6c4c1ec01e64585b4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 23:27:01 +0300 Subject: karras samplers for img2img? --- modules/sd_samplers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 497df943..df17e93c 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -338,9 +338,11 @@ class KDiffusionSampler: steps, t_enc = setup_img2img_steps(p, steps) if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) + sigmas = p.sampler_noise_scheduler_override(steps) + elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': + sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) else: - sigmas = self.model_wrap.get_sigmas(steps) + sigmas = self.model_wrap.get_sigmas(steps) noise = noise * sigmas[steps - t_enc - 1] xi = x + noise -- cgit v1.2.3 From f174fb29228a04955fb951b32b0bab79e33ec2b8 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:21:49 +0300 Subject: add xformers attention --- modules/sd_hijack_optimizations.py | 39 +++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ea4cfdfc..da1b76e1 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,9 @@ import math import torch from torch import einsum - +import xformers.ops +import functorch +xformers._is_functorch_available=True from ldm.util import default from einops import rearrange @@ -92,6 +94,41 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) +def _maybe_init(self, x): + """ + Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x + : B, Head, Length + """ + if self.attention_op is not None: + return + _, M, K = x.shape + try: + self.attention_op = xformers.ops.AttentionOpDispatch( + dtype=x.dtype, + device=x.device, + k=K, + attn_bias_type=type(None), + has_dropout=False, + kv_len=M, + q_len=M, + ).op + except NotImplementedError as err: + raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}") + +def xformers_attention_forward(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + context = default(context, x) + k_in = self.to_k(context) + v_in = self.to_v(context) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + self._maybe_init(q) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) -- cgit v1.2.3 From 2eb911b056ce6ff4434f673366782ed34f2b2f12 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:22:28 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a6fa890c..6221ed5a 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -20,12 +20,17 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): - ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward + if cmd_opts.opt_split_attention: + ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward + ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + elif not cmd_opts.disable_opt_xformers_attention: + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward + ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init + ldm.modules.attention.CrossAttention.attention_op = None + ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.3 From da4ab2707b4cb0611cf181ba248a271d1937433e Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:23:06 +0300 Subject: Update shared.py --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 25bb6e6c..8cc3b2fe 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -43,6 +43,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) +parser.add_argument("--disable-opt-xformers-attention", action='store_true', help="force-disables xformers attention optimization") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -- cgit v1.2.3 From 35d6b231628d18d53d166c3a92fea1523e88d51e Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:31:53 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6221ed5a..a006c0a3 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -20,17 +20,16 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): + ldm.modules.diffusionmodules.model.nonlinearity = silu if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 if cmd_opts.opt_split_attention: ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward elif not cmd_opts.disable_opt_xformers_attention: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init ldm.modules.attention.CrossAttention.attention_op = None - ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.3 From 5303df24282ba06abb34a423f2967354d37d078e Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 06:01:14 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a006c0a3..ddacb0ad 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -23,10 +23,10 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - if cmd_opts.opt_split_attention: + elif cmd_opts.opt_split_attention: ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - elif not cmd_opts.disable_opt_xformers_attention: + elif not cmd_opts.disable_opt_xformers_attention and not cmd_opts.opt_split_attention: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init ldm.modules.attention.CrossAttention.attention_op = None -- cgit v1.2.3 From 5e3ff846c56dc8e1d5c76ea04a8f2f74d7da07fc Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 06:38:01 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ddacb0ad..cbdb9d3c 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -26,7 +26,7 @@ def apply_optimizations(): elif cmd_opts.opt_split_attention: ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - elif not cmd_opts.disable_opt_xformers_attention and not cmd_opts.opt_split_attention: + elif not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init ldm.modules.attention.CrossAttention.attention_op = None -- cgit v1.2.3 From bad7cb29cecac51c5c0f39afec332b007ed73133 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 10:17:52 +0300 Subject: added support for hypernetworks (???) --- modules/hypernetwork.py | 55 ++++++++++++++++++++++++++++++++++++++ modules/sd_hijack_optimizations.py | 17 ++++++++++-- modules/shared.py | 9 ++++++- 3 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 modules/hypernetwork.py (limited to 'modules') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py new file mode 100644 index 00000000..9ed1eed9 --- /dev/null +++ b/modules/hypernetwork.py @@ -0,0 +1,55 @@ +import glob +import os +import torch +from modules import devices + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + self.load_state_dict(state_dict, strict=True) + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, filename): + self.filename = filename + self.name = os.path.splitext(os.path.basename(filename))[0] + self.layers = {} + + state_dict = torch.load(filename, map_location='cpu') + for size, sd in state_dict.items(): + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + +def load_hypernetworks(path): + res = {} + + for filename in glob.iglob(path + '**/*.pt', recursive=True): + hn = Hypernetwork(filename) + res[hn.name] = hn + + return res + +def apply(self, x, context=None, mask=None, original=None): + + + if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork: + if context.shape[1] == 77 and CrossAttention.noise_cond: + context = context + (torch.randn_like(context) * 0.1) + h_k, h_v = CrossAttention.hypernetwork[context.shape[2]] + k = self.to_k(h_k(context)) + v = self.to_v(h_v(context)) + else: + k = self.to_k(context) + v = self.to_v(context) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ea4cfdfc..d9cca485 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -5,6 +5,8 @@ from torch import einsum from ldm.util import default from einops import rearrange +from modules import shared + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): @@ -42,8 +44,19 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) + + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) + + k_in *= self.scale + del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) diff --git a/modules/shared.py b/modules/shared.py index 25bb6e6c..879d8424 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers +from modules import sd_samplers, hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -76,6 +76,12 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram config_filename = cmd_opts.ui_settings_file +hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks')) + + +def selected_hypernetwork(): + return hypernetworks.get(opts.sd_hypernetwork, None) + class State: interrupted = False @@ -206,6 +212,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), + "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), -- cgit v1.2.3 From d15b3ec0013c10f02f0fb80e8448bac8872a151f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 10:40:22 +0300 Subject: support loading VAE --- modules/sd_models.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 5f992064..8f794b47 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -134,6 +134,14 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + if os.path.exists(vae_file): + print(f"Loading VAE weights from: {vae_file}") + vae_ckpt = torch.load(vae_file, map_location="cpu") + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + + model.first_stage_model.load_state_dict(vae_dict) + model.sd_model_hash = sd_model_hash model.sd_model_checkpint = checkpoint_file -- cgit v1.2.3 From 97bc0b9504572d2df80598d0b694703bcd626de6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 13:22:50 +0300 Subject: do not stop working on failed hypernetwork load --- modules/hypernetwork.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 9ed1eed9..c5cf4afa 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -1,5 +1,8 @@ import glob import os +import sys +import traceback + import torch from modules import devices @@ -36,8 +39,12 @@ def load_hypernetworks(path): res = {} for filename in glob.iglob(path + '**/*.pt', recursive=True): - hn = Hypernetwork(filename) - res[hn.name] = hn + try: + hn = Hypernetwork(filename) + res[hn.name] = hn + except Exception: + print(f"Error loading hypernetwork {filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) return res -- cgit v1.2.3 From f7c787eb7c295c27439f4fbdf78c26b8389560be Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 16:39:51 +0300 Subject: make it possible to use hypernetworks without opt split attention --- modules/hypernetwork.py | 42 ++++++++++++++++++++++++++++++++++-------- modules/sd_hijack.py | 6 ++++-- 2 files changed, 38 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index c5cf4afa..c7b86682 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -4,7 +4,12 @@ import sys import traceback import torch -from modules import devices + +from ldm.util import default +from modules import devices, shared +import torch +from torch import einsum +from einops import rearrange, repeat class HypernetworkModule(torch.nn.Module): @@ -48,15 +53,36 @@ def load_hypernetworks(path): return res -def apply(self, x, context=None, mask=None, original=None): +def attention_CrossAttention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) - if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork: - if context.shape[1] == 77 and CrossAttention.noise_cond: - context = context + (torch.randn_like(context) * 0.1) - h_k, h_v = CrossAttention.hypernetwork[context.shape[2]] - k = self.to_k(h_k(context)) - v = self.to_v(h_v(context)) + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k = self.to_k(hypernetwork_layers[0](context)) + v = self.to_v(hypernetwork_layers[1](context)) else: k = self.to_k(context) v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a6fa890c..d68f89cc 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -8,7 +8,7 @@ from torch import einsum from torch.nn.functional import silu import modules.textual_inversion.textual_inversion -from modules import prompt_parser, devices, sd_hijack_optimizations, shared +from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork from modules.shared import opts, device, cmd_opts import ldm.modules.attention @@ -20,6 +20,8 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): + undo_optimizations() + ldm.modules.diffusionmodules.model.nonlinearity = silu if cmd_opts.opt_split_attention_v1: @@ -30,7 +32,7 @@ def apply_optimizations(): def undo_optimizations(): - ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward -- cgit v1.2.3 From 54fa613c8391e3973cca9d94cdf539061932508b Mon Sep 17 00:00:00 2001 From: Greendayle Date: Fri, 7 Oct 2022 20:37:43 +0200 Subject: loading tf only in interrogation process --- modules/deepbooru.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index fb5018a6..79dc59bd 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -1,12 +1,13 @@ import os.path from concurrent.futures import ProcessPoolExecutor -import numpy as np -import deepdanbooru as dd -import tensorflow as tf def _load_tf_and_return_tags(pil_image, threshold): + import deepdanbooru as dd + import tensorflow as tf + import numpy as np + this_folder = os.path.dirname(__file__) model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28') -- cgit v1.2.3 From fa2ea648db81f5723bb5d722f2fe0ebd7dfc319a Mon Sep 17 00:00:00 2001 From: Greendayle Date: Fri, 7 Oct 2022 20:46:38 +0200 Subject: even more powerfull fix --- modules/deepbooru.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 79dc59bd..60094336 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -60,8 +60,13 @@ def _load_tf_and_return_tags(pil_image, threshold): return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') +def subprocess_init_no_cuda(): + import os + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + + def get_deepbooru_tags(pil_image, threshold=0.5): - with ProcessPoolExecutor() as executor: - f = executor.submit(_load_tf_and_return_tags, pil_image, threshold) + with ProcessPoolExecutor(initializer=subprocess_init_no_cuda) as executor: + f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, ) ret = f.result() # will rethrow any exceptions return ret \ No newline at end of file -- cgit v1.2.3 From 5f12e7efd92ad802742f96788b4be3249ad02829 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Fri, 7 Oct 2022 20:58:30 +0200 Subject: linux test --- modules/deepbooru.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 60094336..781b2249 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -1,6 +1,6 @@ import os.path from concurrent.futures import ProcessPoolExecutor - +from multiprocessing import get_context def _load_tf_and_return_tags(pil_image, threshold): @@ -66,7 +66,8 @@ def subprocess_init_no_cuda(): def get_deepbooru_tags(pil_image, threshold=0.5): - with ProcessPoolExecutor(initializer=subprocess_init_no_cuda) as executor: + context = get_context('spawn') + with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor: f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, ) ret = f.result() # will rethrow any exceptions return ret \ No newline at end of file -- cgit v1.2.3 From 12c4d5c6b5bf9dd50d0601c36af4f99b65316d58 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 23:22:22 +0300 Subject: hypernetwork training mk1 --- modules/hypernetwork.py | 88 ------------ modules/hypernetwork/hypernetwork.py | 267 +++++++++++++++++++++++++++++++++++ modules/hypernetwork/ui.py | 43 ++++++ modules/sd_hijack.py | 4 +- modules/sd_hijack_optimizations.py | 3 +- modules/shared.py | 13 +- modules/textual_inversion/ui.py | 1 - modules/ui.py | 58 +++++++- 8 files changed, 374 insertions(+), 103 deletions(-) delete mode 100644 modules/hypernetwork.py create mode 100644 modules/hypernetwork/hypernetwork.py create mode 100644 modules/hypernetwork/ui.py (limited to 'modules') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py deleted file mode 100644 index c7b86682..00000000 --- a/modules/hypernetwork.py +++ /dev/null @@ -1,88 +0,0 @@ -import glob -import os -import sys -import traceback - -import torch - -from ldm.util import default -from modules import devices, shared -import torch -from torch import einsum -from einops import rearrange, repeat - - -class HypernetworkModule(torch.nn.Module): - def __init__(self, dim, state_dict): - super().__init__() - - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) - - self.load_state_dict(state_dict, strict=True) - self.to(devices.device) - - def forward(self, x): - return x + (self.linear2(self.linear1(x))) - - -class Hypernetwork: - filename = None - name = None - - def __init__(self, filename): - self.filename = filename - self.name = os.path.splitext(os.path.basename(filename))[0] - self.layers = {} - - state_dict = torch.load(filename, map_location='cpu') - for size, sd in state_dict.items(): - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) - - -def load_hypernetworks(path): - res = {} - - for filename in glob.iglob(path + '**/*.pt', recursive=True): - try: - hn = Hypernetwork(filename) - res[hn.name] = hn - except Exception: - print(f"Error loading hypernetwork {filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - return res - - -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - hypernetwork = shared.selected_hypernetwork() - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k = self.to_k(hypernetwork_layers[0](context)) - v = self.to_v(hypernetwork_layers[1](context)) - else: - k = self.to_k(context) - v = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - if mask is not None: - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py new file mode 100644 index 00000000..a3d6a47e --- /dev/null +++ b/modules/hypernetwork/hypernetwork.py @@ -0,0 +1,267 @@ +import datetime +import glob +import html +import os +import sys +import traceback +import tqdm + +import torch + +from ldm.util import default +from modules import devices, shared, processing, sd_models +import torch +from torch import einsum +from einops import rearrange, repeat +import modules.textual_inversion.dataset + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict=None): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + if state_dict is not None: + self.load_state_dict(state_dict, strict=True) + else: + self.linear1.weight.data.fill_(0.0001) + self.linear1.bias.data.fill_(0.0001) + self.linear2.weight.data.fill_(0.0001) + self.linear2.bias.data.fill_(0.0001) + + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, name=None): + self.filename = None + self.name = name + self.layers = {} + self.step = 0 + self.sd_checkpoint = None + self.sd_checkpoint_name = None + + for size in [320, 640, 768, 1280]: + self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) + + def weights(self): + res = [] + + for k, layers in self.layers.items(): + for layer in layers: + layer.train() + res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] + + return res + + def save(self, filename): + state_dict = {} + + for k, v in self.layers.items(): + state_dict[k] = (v[0].state_dict(), v[1].state_dict()) + + state_dict['step'] = self.step + state_dict['name'] = self.name + state_dict['sd_checkpoint'] = self.sd_checkpoint + state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name + + torch.save(state_dict, filename) + + def load(self, filename): + self.filename = filename + if self.name is None: + self.name = os.path.splitext(os.path.basename(filename))[0] + + state_dict = torch.load(filename, map_location='cpu') + + for size, sd in state_dict.items(): + if type(size) == int: + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + self.name = state_dict.get('name', self.name) + self.step = state_dict.get('step', 0) + self.sd_checkpoint = state_dict.get('sd_checkpoint', None) + self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) + + +def load_hypernetworks(path): + res = {} + + for filename in glob.iglob(path + '**/*.pt', recursive=True): + try: + hn = Hypernetwork() + hn.load(filename) + res[hn.name] = hn + except Exception: + print(f"Error loading hypernetwork {filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + return res + + +def attention_CrossAttention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + hypernetwork_k, hypernetwork_v = hypernetwork_layers + + self.hypernetwork_k = hypernetwork_k + self.hypernetwork_v = hypernetwork_v + + context_k = hypernetwork_k(context) + context_v = hypernetwork_v(context) + else: + context_k = context + context_v = context + + k = self.to_k(context_k) + v = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): + assert hypernetwork_name, 'embedding not selected' + + shared.hypernetwork = shared.hypernetworks[hypernetwork_name] + + shared.state.textinfo = "Initializing hypernetwork training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) + + if save_hypernetwork_every > 0: + hypernetwork_dir = os.path.join(log_directory, "hypernetworks") + os.makedirs(hypernetwork_dir, exist_ok=True) + else: + hypernetwork_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + cond_model = shared.sd_model.cond_stage_model + + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + with torch.autocast("cuda"): + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + + hypernetwork = shared.hypernetworks[hypernetwork_name] + weights = hypernetwork.weights() + for weight in weights: + weight.requires_grad = True + + optimizer = torch.optim.AdamW(weights, lr=learn_rate) + + losses = torch.zeros((32,)) + + last_saved_file = "" + last_saved_image = "" + + ititial_step = hypernetwork.step or 0 + if ititial_step > steps: + return hypernetwork, filename + + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + for i, (x, text) in pbar: + hypernetwork.step = i + ititial_step + + if hypernetwork.step > steps: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([text]) + + x = x.to(devices.device) + loss = shared.sd_model(x.unsqueeze(0), c)[0] + del x + + losses[hypernetwork.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description(f"loss: {losses.mean():.7f}") + + if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') + hypernetwork.save(last_saved_file) + + if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: + last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') + + preview_text = text if preview_image_prompt == "" else preview_image_prompt + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + prompt=preview_text, + steps=20, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + processed = processing.process_images(p) + image = processed.images[0] + + shared.state.current_image = image + image.save(last_saved_image) + + last_saved_image += f", prompt: {preview_text}" + + shared.state.job_no = hypernetwork.step + + shared.state.textinfo = f""" +

+Loss: {losses.mean():.7f}
+Step: {hypernetwork.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+

+""" + + checkpoint = sd_models.select_checkpoint() + + hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint_name = checkpoint.model_name + hypernetwork.save(filename) + + return hypernetwork, filename + + diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py new file mode 100644 index 00000000..525f978c --- /dev/null +++ b/modules/hypernetwork/ui.py @@ -0,0 +1,43 @@ +import html +import os + +import gradio as gr + +import modules.textual_inversion.textual_inversion +import modules.textual_inversion.preprocess +from modules import sd_hijack, shared + + +def create_hypernetwork(name): + fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernetwork.save(fn) + + shared.reload_hypernetworks() + shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None) + + return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" + + +def train_hypernetwork(*args): + + initial_hypernetwork = shared.hypernetwork + + try: + sd_hijack.undo_optimizations() + + hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. +Hypernetwork saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + shared.hypernetwork = initial_hypernetwork + sd_hijack.apply_optimizations() + diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index d68f89cc..ec8c9d4b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -8,7 +8,7 @@ from torch import einsum from torch.nn.functional import silu import modules.textual_inversion.textual_inversion -from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork +from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules.shared import opts, device, cmd_opts import ldm.modules.attention @@ -32,6 +32,8 @@ def apply_optimizations(): def undo_optimizations(): + from modules.hypernetwork import hypernetwork + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index d9cca485..3f32e020 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -45,8 +45,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: k_in = self.to_k(hypernetwork_layers[0](context)) diff --git a/modules/shared.py b/modules/shared.py index 879d8424..c5a893e8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, hypernetwork +from modules import sd_samplers from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -28,6 +28,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the mo parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") @@ -76,11 +77,15 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram config_filename = cmd_opts.ui_settings_file -hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks')) +def reload_hypernetworks(): + from modules.hypernetwork import hypernetwork + hypernetworks.clear() + hypernetworks.update(hypernetwork.load_hypernetworks(cmd_opts.hypernetwork_dir)) -def selected_hypernetwork(): - return hypernetworks.get(opts.sd_hypernetwork, None) + +hypernetworks = {} +hypernetwork = None class State: diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index f19ac5e0..c57de1f9 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -22,7 +22,6 @@ def preprocess(*args): def train_embedding(*args): - try: sd_hijack.undo_optimizations() diff --git a/modules/ui.py b/modules/ui.py index 4f18126f..051908c1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -37,6 +37,7 @@ import modules.generation_parameters_copypaste from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui +import modules.hypernetwork.ui # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI mimetypes.init() @@ -965,6 +966,18 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create", variant='primary') + with gr.Group(): + gr.HTML(value="

Create a new hypernetwork

") + + new_hypernetwork_name = gr.Textbox(label="Name") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_hypernetwork = gr.Button(value="Create", variant='primary') + with gr.Group(): gr.HTML(value="

Preprocess images

") @@ -986,6 +999,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): gr.HTML(value="

Train an embedding; must specify a directory with a set of 512x512 images

") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) learn_rate = gr.Number(label='Learning rate', value=5.0e-03) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") @@ -993,15 +1007,12 @@ def create_ui(wrap_gradio_gpu_call): steps = gr.Number(label='Max steps', value=100000, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) + preview_image_prompt = gr.Textbox(label='Preview prompt', value="") with gr.Row(): - with gr.Column(scale=2): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_training = gr.Button(value="Interrupt") - train_embedding = gr.Button(value="Train", variant='primary') + interrupt_training = gr.Button(value="Interrupt") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') + train_embedding = gr.Button(value="Train Embedding", variant='primary') with gr.Column(): progressbar = gr.HTML(elem_id="ti_progressbar") @@ -1027,6 +1038,18 @@ def create_ui(wrap_gradio_gpu_call): ] ) + create_hypernetwork.click( + fn=modules.hypernetwork.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + run_preprocess.click( fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", @@ -1062,12 +1085,33 @@ def create_ui(wrap_gradio_gpu_call): ] ) + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetwork.ui.train_hypernetwork, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_hypernetwork_name, + learn_rate, + dataset_directory, + log_directory, + steps, + create_image_every, + save_embedding_every, + template_file, + preview_image_prompt, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + interrupt_training.click( fn=lambda: shared.state.interrupt(), inputs=[], outputs=[], ) + def create_setting_component(key): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default -- cgit v1.2.3 From c9cc65b201679ea43c763b0d85e749d40bbc5433 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 04:09:18 +0300 Subject: switch to the proper way of calling xformers --- modules/sd_hijack_optimizations.py | 28 +++------------------------- 1 file changed, 3 insertions(+), 25 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index da1b76e1..7fb4a45e 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -94,39 +94,17 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def _maybe_init(self, x): - """ - Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x - : B, Head, Length - """ - if self.attention_op is not None: - return - _, M, K = x.shape - try: - self.attention_op = xformers.ops.AttentionOpDispatch( - dtype=x.dtype, - device=x.device, - k=K, - attn_bias_type=type(None), - has_dropout=False, - kv_len=M, - q_len=M, - ).op - except NotImplementedError as err: - raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}") - def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) k_in = self.to_k(context) v_in = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - self._maybe_init(q) - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) def cross_attention_attnblock_forward(self, x): -- cgit v1.2.3 From b70eaeb2005a5a9593119e7fd32b8072c2a208d5 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 04:10:35 +0300 Subject: delete broken and unnecessary aliases --- modules/sd_hijack.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index cbdb9d3c..0e99c319 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -21,16 +21,14 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.opt_split_attention_v1: + if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif cmd_opts.opt_split_attention: ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - elif not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init - ldm.modules.attention.CrossAttention.attention_op = None - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward def undo_optimizations(): -- cgit v1.2.3 From f2055cb1d4ce45d7aaacc49d8ab5bec7791a8f47 Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 8 Oct 2022 01:47:02 -0400 Subject: Add hypernetwork support to split cross attention v1 * Add hypernetwork support to split_cross_attention_forward_v1 * Fix device check in esrgan_model.py to use devices.device_esrgan instead of shared.device --- modules/esrgan_model.py | 2 +- modules/sd_hijack_optimizations.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index d17e730f..28548124 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -111,7 +111,7 @@ class UpscalerESRGAN(Upscaler): print("Unable to load %s from %s" % (self.model_path, filename)) return None - pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None) + pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) pretrained_net = fix_model_layers(crt_model, pretrained_net) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index d9cca485..3351c740 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -12,13 +12,22 @@ from modules import shared def split_cross_attention_forward_v1(self, x, context=None, mask=None): h = self.heads - q = self.to_q(x) + q_in = self.to_q(x) context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) + + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) for i in range(0, q.shape[0], 2): @@ -31,6 +40,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) del s2 + del q, k, v r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 -- cgit v1.2.3 From 5d54f35c583bd5a3b0ee271a862827f1ca81ef09 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 11:55:02 +0300 Subject: add xformers attnblock and hypernetwork support --- modules/sd_hijack_optimizations.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 7fb4a45e..c78d5838 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -98,8 +98,14 @@ def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - k_in = self.to_k(context) - v_in = self.to_v(context) + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) @@ -169,3 +175,13 @@ def cross_attention_attnblock_forward(self, x): h3 += x return h3 + + def xformers_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q1 = self.q(h_).contiguous() + k1 = self.k(h_).contiguous() + v = self.v(h_).contiguous() + out = xformers.ops.memory_efficient_attention(q1, k1, v) + out = self.proj_out(out) + return x+out -- cgit v1.2.3 From 76a616fa6b814c681eaf6edc87eb3001b8c2b6be Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 11:55:38 +0300 Subject: Update sd_hijack_optimizations.py --- modules/sd_hijack_optimizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index c78d5838..ee58c7e4 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -176,7 +176,7 @@ def cross_attention_attnblock_forward(self, x): return h3 - def xformers_attnblock_forward(self, x): +def xformers_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) q1 = self.q(h_).contiguous() -- cgit v1.2.3 From 91d66f5520df416db718103d460550ad495e952d Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 11:56:01 +0300 Subject: use new attnblock for xformers path --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 0e99c319..3da8c8ce 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -23,7 +23,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif cmd_opts.opt_split_attention: -- cgit v1.2.3 From 616b7218f7c469d25c138634472017a7e18e742e Mon Sep 17 00:00:00 2001 From: leko Date: Fri, 7 Oct 2022 23:09:21 +0800 Subject: fix: handles when state_dict does not exist --- modules/sd_models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 8f794b47..9409d070 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -122,7 +122,11 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): pl_sd = torch.load(checkpoint_file, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] + + if "state_dict" in pl_sd: + sd = pl_sd["state_dict"] + else: + sd = pl_sd model.load_state_dict(sd, strict=False) -- cgit v1.2.3 From 706d5944a075a6523ea7f00165d630efc085ca22 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 13:38:57 +0300 Subject: let user choose his own prompt token count limit --- modules/processing.py | 6 ++++++ modules/sd_hijack.py | 13 +++++++------ modules/shared.py | 5 +++-- 3 files changed, 16 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index f773a30e..d814d5ac 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -123,6 +123,7 @@ class Processed: self.index_of_first_image = index_of_first_image self.styles = p.styles self.job_timestamp = state.job_timestamp + self.max_prompt_tokens = opts.max_prompt_tokens self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -141,6 +142,7 @@ class Processed: self.all_subseeds = all_subseeds or [self.subseed] self.infotexts = infotexts or [info] + def js(self): obj = { "prompt": self.prompt, @@ -169,6 +171,7 @@ class Processed: "infotexts": self.infotexts, "styles": self.styles, "job_timestamp": self.job_timestamp, + "max_prompt_tokens": self.max_prompt_tokens, } return json.dumps(obj) @@ -266,6 +269,8 @@ def fix_seed(p): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): index = position_in_batch + iteration * p.batch_size + max_tokens = getattr(p, 'max_prompt_tokens', opts.max_prompt_tokens) + generation_params = { "Steps": p.steps, "Sampler": sd_samplers.samplers[p.sampler_index].name, @@ -281,6 +286,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), + "Max tokens": (None if max_tokens == shared.vanilla_max_prompt_tokens else max_tokens) } generation_params.update(p.extra_generation_params) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index d68f89cc..340329c0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -18,7 +18,6 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward - def apply_optimizations(): undo_optimizations() @@ -83,7 +82,7 @@ class StableDiffusionModelHijack: layer.padding_mode = 'circular' if enable else 'zeros' def tokenize(self, text): - max_length = self.clip.max_length - 2 + max_length = opts.max_prompt_tokens - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, max_length @@ -94,7 +93,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.wrapped = wrapped self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer - self.max_length = wrapped.max_length self.token_mults = {} tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] @@ -116,7 +114,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length + maxlen = opts.max_prompt_tokens if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) @@ -191,7 +189,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def process_text_old(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length + maxlen = self.wrapped.max_length # you get to stay at 77 used_custom_terms = [] remade_batch_tokens = [] overflowing_words = [] @@ -268,8 +266,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + position_ids_array = [min(x, 75) for x in range(len(remade_batch_tokens[0])-1)] + [76] + position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) + tokens = torch.asarray(remade_batch_tokens).to(device) - outputs = self.wrapped.transformer(input_ids=tokens) + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise diff --git a/modules/shared.py b/modules/shared.py index 879d8424..864e772c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -118,8 +118,8 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename) interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] -# This was moved to webui.py with the other model "setup" calls. -# modules.sd_models.list_models() + +vanilla_max_prompt_tokens = 77 def realesrgan_models_names(): @@ -221,6 +221,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), + "max_prompt_tokens": OptionInfo(vanilla_max_prompt_tokens, f"Max prompt token count. Two tokens are reserved for for start and end. Default is {vanilla_max_prompt_tokens}. Setting this to a different value will result in different pictures for same seed.", gr.Number, {"precision": 0}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.3 From 786d9f63aaa4515df82eb2cf357ea92f3dae1e29 Mon Sep 17 00:00:00 2001 From: Trung Ngo Date: Tue, 4 Oct 2022 22:56:30 -0500 Subject: Add button to skip the current iteration --- modules/img2img.py | 4 ++++ modules/processing.py | 4 ++++ modules/shared.py | 5 +++++ modules/ui.py | 8 ++++++++ 4 files changed, 21 insertions(+) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index da212d72..e60b7e0f 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -32,6 +32,10 @@ def process_batch(p, input_dir, output_dir, args): for i, image in enumerate(images): state.job = f"{i+1} out of {len(images)}" + if state.skipped: + state.skipped = False + state.interrupted = False + continue if state.interrupted: break diff --git a/modules/processing.py b/modules/processing.py index d814d5ac..6805039c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -355,6 +355,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: state.job_count = p.n_iter for n in range(p.n_iter): + if state.skipped: + state.skipped = False + state.interrupted = False + if state.interrupted: break diff --git a/modules/shared.py b/modules/shared.py index 864e772c..7f802bd9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -84,6 +84,7 @@ def selected_hypernetwork(): class State: + skipped = False interrupted = False job = "" job_no = 0 @@ -96,6 +97,10 @@ class State: current_image_sampling_step = 0 textinfo = None + def skip(self): + self.skipped = True + self.interrupted = True + def interrupt(self): self.interrupted = True diff --git a/modules/ui.py b/modules/ui.py index 4f18126f..e3e62fdd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -191,6 +191,7 @@ def wrap_gradio_call(func, extra_outputs=None): # last item is always HTML res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" + shared.state.skipped = False shared.state.interrupted = False shared.state.job_count = 0 @@ -411,9 +412,16 @@ def create_toprow(is_img2img): with gr.Column(scale=1): with gr.Row(): + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + interrupt.click( fn=lambda: shared.state.interrupt(), inputs=[], -- cgit v1.2.3 From 00117a07efbbe8482add12262a179326541467de Mon Sep 17 00:00:00 2001 From: Trung Ngo Date: Sat, 8 Oct 2022 05:33:21 -0500 Subject: check specifically for skipped --- modules/img2img.py | 2 -- modules/processing.py | 3 +-- modules/sd_samplers.py | 4 ++-- modules/shared.py | 1 - 4 files changed, 3 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index e60b7e0f..24126774 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -34,8 +34,6 @@ def process_batch(p, input_dir, output_dir, args): state.job = f"{i+1} out of {len(images)}" if state.skipped: state.skipped = False - state.interrupted = False - continue if state.interrupted: break diff --git a/modules/processing.py b/modules/processing.py index 6805039c..3657fe69 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -357,7 +357,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed: for n in range(p.n_iter): if state.skipped: state.skipped = False - state.interrupted = False if state.interrupted: break @@ -385,7 +384,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) - if state.interrupted: + if state.interrupted or state.skipped: # if we are interruped, sample returns just noise # use the image collected previously in sampler loop diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index df17e93c..13a8b322 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -106,7 +106,7 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs): seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs) for x in seq: - if state.interrupted: + if state.interrupted or state.skipped: break yield x @@ -254,7 +254,7 @@ def extended_trange(sampler, count, *args, **kwargs): seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs) for x in seq: - if state.interrupted: + if state.interrupted or state.skipped: break if sampler.stop_at is not None and x > sampler.stop_at: diff --git a/modules/shared.py b/modules/shared.py index 7f802bd9..ca462628 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -99,7 +99,6 @@ class State: def skip(self): self.skipped = True - self.interrupted = True def interrupt(self): self.interrupted = True -- cgit v1.2.3 From 4999eb2ef9b30e8c42ca7e4a94d4bbffe4d1f015 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 14:25:47 +0300 Subject: do not let user choose his own prompt token count limit --- modules/processing.py | 5 ----- modules/sd_hijack.py | 25 ++++++++++++------------- modules/shared.py | 3 --- 3 files changed, 12 insertions(+), 21 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 3657fe69..d5162ddc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -123,7 +123,6 @@ class Processed: self.index_of_first_image = index_of_first_image self.styles = p.styles self.job_timestamp = state.job_timestamp - self.max_prompt_tokens = opts.max_prompt_tokens self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -171,7 +170,6 @@ class Processed: "infotexts": self.infotexts, "styles": self.styles, "job_timestamp": self.job_timestamp, - "max_prompt_tokens": self.max_prompt_tokens, } return json.dumps(obj) @@ -269,8 +267,6 @@ def fix_seed(p): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): index = position_in_batch + iteration * p.batch_size - max_tokens = getattr(p, 'max_prompt_tokens', opts.max_prompt_tokens) - generation_params = { "Steps": p.steps, "Sampler": sd_samplers.samplers[p.sampler_index].name, @@ -286,7 +282,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), - "Max tokens": (None if max_tokens == shared.vanilla_max_prompt_tokens else max_tokens) } generation_params.update(p.extra_generation_params) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 340329c0..2c1332c9 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -36,6 +36,13 @@ def undo_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward +def get_target_prompt_token_count(token_count): + if token_count < 75: + return 75 + + return math.ceil(token_count / 10) * 10 + + class StableDiffusionModelHijack: fixes = None comments = [] @@ -84,7 +91,7 @@ class StableDiffusionModelHijack: def tokenize(self, text): max_length = opts.max_prompt_tokens - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) - return remade_batch_tokens[0], token_count, max_length + return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): @@ -114,7 +121,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = opts.max_prompt_tokens if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) @@ -146,19 +152,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): used_custom_terms.append((embedding.name, embedding.checksum())) i += embedding_length_in_tokens - if len(remade_tokens) > maxlen - 2: - vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} - ovf = remade_tokens[maxlen - 2:] - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - token_count = len(remade_tokens) - remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + prompt_target_length = get_target_prompt_token_count(token_count) + tokens_to_add = prompt_target_length - len(remade_tokens) + 1 - multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) - multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add + multipliers = [1.0] + multipliers + [1.0] * tokens_to_add return remade_tokens, fixes, multipliers, token_count diff --git a/modules/shared.py b/modules/shared.py index ca462628..475d7e52 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -123,8 +123,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] -vanilla_max_prompt_tokens = 77 - def realesrgan_models_names(): import modules.realesrgan_model @@ -225,7 +223,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), - "max_prompt_tokens": OptionInfo(vanilla_max_prompt_tokens, f"Max prompt token count. Two tokens are reserved for for start and end. Default is {vanilla_max_prompt_tokens}. Setting this to a different value will result in different pictures for same seed.", gr.Number, {"precision": 0}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.3 From 77f4237d1c3af1756e7dab2699e3dcebad5619d6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 15:25:59 +0300 Subject: fix bugs related to variable prompt lengths --- modules/sd_hijack.py | 14 +++++++++----- modules/sd_samplers.py | 35 ++++++++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2c1332c9..7e7fde0f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -89,7 +89,6 @@ class StableDiffusionModelHijack: layer.padding_mode = 'circular' if enable else 'zeros' def tokenize(self, text): - max_length = opts.max_prompt_tokens - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) @@ -174,7 +173,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if line in cache: remade_tokens, fixes, multipliers = cache[line] else: - remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + token_count = max(current_token_count, token_count) cache[line] = (remade_tokens, fixes, multipliers) @@ -265,15 +265,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - position_ids_array = [min(x, 75) for x in range(len(remade_batch_tokens[0])-1)] + [76] + target_token_count = get_target_prompt_token_count(token_count) + 2 + + position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76] position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) - tokens = torch.asarray(remade_batch_tokens).to(device) + remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] + tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers = torch.asarray(batch_multipliers).to(device) + batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] + batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) original_mean = z.mean() z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 13a8b322..eade0dbb 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -142,6 +142,16 @@ class VanillaStableDiffusionSampler: assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' cond = tensor + # for DDIM, shapes must match, we can't just process cond and uncond independently; + # filling unconditional_conditioning with repeats of the last vector to match length is + # not 100% correct but should work well enough + if unconditional_conditioning.shape[1] < cond.shape[1]: + last_vector = unconditional_conditioning[:, -1:] + last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1]) + unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated]) + elif unconditional_conditioning.shape[1] > cond.shape[1]: + unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]] + if self.mask is not None: img_orig = self.sampler.model.q_sample(self.init_latent, ts) x_dec = img_orig * self.mask + self.nmask * x_dec @@ -221,18 +231,29 @@ class CFGDenoiser(torch.nn.Module): x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - cond_in = torch.cat([tensor, uncond]) - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=cond_in) + if tensor.shape[1] == uncond.shape[1]: + cond_in = torch.cat([tensor, uncond]) + + if shared.batch_cond_uncond: + x_out = self.inner_model(x_in, sigma_in, cond=cond_in) + else: + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b]) else: x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): + batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size + for batch_offset in range(0, tensor.shape[0], batch_size): a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b]) + b = min(a + batch_size, tensor.shape[0]) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b]) + + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond) - denoised_uncond = x_out[-batch_size:] + denoised_uncond = x_out[-uncond.shape[0]:] denoised = torch.clone(denoised_uncond) for i, conds in enumerate(conds_list): -- cgit v1.2.3 From 7001bffe0247804793dfabb69ac96d832572ccd0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 15:43:25 +0300 Subject: fix AND broken for long prompts --- modules/prompt_parser.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'modules') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index f00256f2..15666073 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -239,6 +239,15 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): conds_list.append(conds_for_batch) + # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes + # and won't be able to torch.stack them. So this fixes that. + token_count = max([x.shape[0] for x in tensors]) + for i in range(len(tensors)): + if tensors[i].shape[0] != token_count: + last_vector = tensors[i][-1:] + last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) + tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) + return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype) -- cgit v1.2.3 From 772db721a52da374d627b60994222051f26c27a7 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Fri, 7 Oct 2022 23:02:07 +0900 Subject: fix glob path in hypernetwork.py --- modules/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index c7b86682..7f062242 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -43,7 +43,7 @@ class Hypernetwork: def load_hypernetworks(path): res = {} - for filename in glob.iglob(path + '**/*.pt', recursive=True): + for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): try: hn = Hypernetwork(filename) res[hn.name] = hn -- cgit v1.2.3 From 5f85a74b00c0154bfd559dc67edfa7e30342b7c9 Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Fri, 7 Oct 2022 17:48:34 -0400 Subject: fix bug where when using prompt composition, hijack_comments generated before the final AND will be dropped --- modules/processing.py | 1 + modules/sd_hijack.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index d5162ddc..8240ee27 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -313,6 +313,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: os.makedirs(p.outpath_grids, exist_ok=True) modules.sd_hijack.model_hijack.apply_circular(p.tiling) + modules.sd_hijack.model_hijack.clear_comments() comments = {} diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 7e7fde0f..ba808a39 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -88,6 +88,9 @@ class StableDiffusionModelHijack: for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]: layer.padding_mode = 'circular' if enable else 'zeros' + def clear_comments(self): + self.comments = [] + def tokenize(self, text): _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) @@ -260,7 +263,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) self.hijack.fixes = hijack_fixes - self.hijack.comments = hijack_comments + self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) -- cgit v1.2.3 From 26b459a3799c5cdf71ca8ed5315a99f69c69f02c Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:20:04 +0300 Subject: default to split attention if cuda is available and xformers is not --- modules/sd_hijack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3da8c8ce..04adcf03 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -21,12 +21,12 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): + if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip or shared.xformers_available): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif cmd_opts.opt_split_attention: + elif cmd_opts.opt_split_attention or torch.cuda.is_available(): ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.3 From ddfa9a97865c732193023a71521c5b7b53d8571b Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:20:41 +0300 Subject: add xformers_available shared variable --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 8cc3b2fe..6ed4b802 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,7 +74,7 @@ device = devices.device batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram - +xformers_available = False config_filename = cmd_opts.ui_settings_file -- cgit v1.2.3 From 69d0053583757ce2942d62de81e8b89e6be07840 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:21:40 +0300 Subject: update sd_hijack_opt to respect new env variables --- modules/sd_hijack_optimizations.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ee58c7e4..be09ec8f 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,9 +1,14 @@ import math import torch from torch import einsum -import xformers.ops -import functorch -xformers._is_functorch_available=True +try: + import xformers.ops + import functorch + xformers._is_functorch_available = True + shared.xformers_available = True +except: + print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.') + continue from ldm.util import default from einops import rearrange -- cgit v1.2.3 From 970de9ee6891ff586821d0d80dde01c2f6c681b3 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:29:43 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 04adcf03..5b30539f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -21,7 +21,7 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip or shared.xformers_available): + if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip) and shared.xformers_available: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: -- cgit v1.2.3 From 7ff1170a2e11b6f00f587407326db0b9f8f51adf Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 16:33:39 +0300 Subject: emergency fix for xformers (continue + shared) --- modules/sd_hijack_optimizations.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index e43e2c7a..05023b6f 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,19 +1,19 @@ import math import torch from torch import einsum -try: - import xformers.ops - import functorch - xformers._is_functorch_available = True - shared.xformers_available = True -except: - print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.') - continue + from ldm.util import default from einops import rearrange from modules import shared +try: + import xformers.ops + import functorch + xformers._is_functorch_available = True + shared.xformers_available = True +except Exception: + print('Cannot find xformers, defaulting to split attention. Try adding --xformers commandline argument to your webui-user file if you wish to install it.') # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): -- cgit v1.2.3 From dc1117233ef8f9b25ff1ac40b158f20b70ba2fcb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 17:02:18 +0300 Subject: simplify xfrmers options: --xformers to enable and that's it --- modules/sd_hijack.py | 2 +- modules/sd_hijack_optimizations.py | 20 +++++++++++++------- modules/shared.py | 2 +- 3 files changed, 15 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5d93f7f6..91e98c16 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,7 +22,7 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip) and shared.xformers_available: + if cmd_opts.xformers and shared.xformers_available and not torch.version.hip: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 05023b6f..d23d733b 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,4 +1,7 @@ import math +import sys +import traceback + import torch from torch import einsum @@ -7,13 +10,16 @@ from einops import rearrange from modules import shared -try: - import xformers.ops - import functorch - xformers._is_functorch_available = True - shared.xformers_available = True -except Exception: - print('Cannot find xformers, defaulting to split attention. Try adding --xformers commandline argument to your webui-user file if you wish to install it.') +if shared.cmd_opts.xformers: + try: + import xformers.ops + import functorch + xformers._is_functorch_available = True + shared.xformers_available = True + except Exception: + print("Cannot import xformers", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): diff --git a/modules/shared.py b/modules/shared.py index d68df751..02cb2722 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -43,7 +43,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) -parser.add_argument("--disable-opt-xformers-attention", action='store_true', help="force-disables xformers attention optimization") +parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -- cgit v1.2.3 From 27032c47df9c07ac21dd5b89fa7dc247bb8705b6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 17:10:05 +0300 Subject: restore old opt_split_attention/disable_opt_split_attention logic --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 91e98c16..335a2bcf 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -27,7 +27,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif cmd_opts.opt_split_attention or torch.cuda.is_available(): + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.3 From 4f33289d0fc5aa3a197f4a4c926d03d44f0d597e Mon Sep 17 00:00:00 2001 From: Milly Date: Sat, 8 Oct 2022 22:48:15 +0900 Subject: Fixed typo --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index e3e62fdd..ffd75f6a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -946,7 +946,7 @@ def create_ui(wrap_gradio_gpu_call): custom_name = gr.Textbox(label="Custom Name (Optional)") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3) interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") - save_as_half = gr.Checkbox(value=False, label="Safe as float16") + save_as_half = gr.Checkbox(value=False, label="Save as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') with gr.Column(variant='panel'): -- cgit v1.2.3 From cfc33f99d47d1f45af15499e5965834089d11858 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 17:28:58 +0300 Subject: why did you do this --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 335a2bcf..ed271976 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -28,7 +28,7 @@ def apply_optimizations(): elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.3 From 017b6b8744f0771e498656ec043e12d5cc6969a7 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 17:27:21 +0300 Subject: check for ampere --- modules/sd_hijack.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ed271976..5e266d5e 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,9 +22,10 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.xformers and shared.xformers_available and not torch.version.hip: - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + if cmd_opts.xformers and shared.xformers_available and torch.version.cuda: + if torch.cuda.get_device_capability(shared.device) == (8, 6): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): -- cgit v1.2.3 From cc0258aea7b6605be3648900063cfa96ed7c5ffa Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 17:44:53 +0300 Subject: check for ampere without destroying the optimizations. again. --- modules/sd_hijack.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5e266d5e..a3e374f0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,10 +22,9 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.xformers and shared.xformers_available and torch.version.cuda: - if torch.cuda.get_device_capability(shared.device) == (8, 6): - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + if cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): -- cgit v1.2.3 From a5550f0213c3f145b1c984816ebcef92c48853ee Mon Sep 17 00:00:00 2001 From: Artem Zagidulin Date: Wed, 5 Oct 2022 19:10:39 +0300 Subject: alternate prompt --- modules/prompt_parser.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 15666073..919d5d31 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -13,13 +13,14 @@ import lark schedule_parser = lark.Lark(r""" !start: (prompt | /[][():]/+)* -prompt: (emphasized | scheduled | plain | WHITESPACE)* +prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)* !emphasized: "(" prompt ")" | "(" prompt ":" prompt ")" | "[" prompt "]" scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]" +alternate: "[" prompt ("|" prompt)+ "]" WHITESPACE: /\s+/ -plain: /([^\\\[\]():]|\\.)+/ +plain: /([^\\\[\]():|]|\\.)+/ %import common.SIGNED_NUMBER -> NUMBER """) @@ -59,6 +60,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): tree.children[-1] *= steps tree.children[-1] = min(steps, int(tree.children[-1])) l.append(tree.children[-1]) + def alternate(self, tree): + l.extend(range(1, steps+1)) CollectSteps().visit(tree) return sorted(set(l)) @@ -67,6 +70,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): def scheduled(self, args): before, after, _, when = args yield before or () if step <= when else after + def alternate(self, args): + yield next(args[(step - 1)%len(args)]) def start(self, args): def flatten(x): if type(x) == str: -- cgit v1.2.3 From 01f8cb44474e454903c11718e6a4f33dbde34bb8 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Sat, 8 Oct 2022 18:02:56 +0200 Subject: made deepdanbooru optional, added to readme, automatic download of deepbooru model --- modules/deepbooru.py | 20 ++++++++++---------- modules/shared.py | 1 + modules/ui.py | 19 ++++++++++++------- 3 files changed, 23 insertions(+), 17 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 781b2249..7e3c0618 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -9,16 +9,16 @@ def _load_tf_and_return_tags(pil_image, threshold): import numpy as np this_folder = os.path.dirname(__file__) - model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28') - - model_good = False - for path_candidate in [model_path, os.path.dirname(model_path)]: - if os.path.exists(os.path.join(path_candidate, 'project.json')): - model_path = path_candidate - model_good = True - if not model_good: - return ("Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/" - "deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru") + model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru')) + if not os.path.exists(os.path.join(model_path, 'project.json')): + # there is no point importing these every time + import zipfile + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", + model_path) + with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref: + zip_ref.extractall(model_path) + os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")) tags = dd.project.load_tags_from_project(model_path) model = dd.project.load_model_from_project( diff --git a/modules/shared.py b/modules/shared.py index 02cb2722..c87b726e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -44,6 +44,7 @@ parser.add_argument("--scunet-models-path", type=str, help="Path to directory wi parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") +parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") diff --git a/modules/ui.py b/modules/ui.py index 30583fe9..c5c11c3c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,9 +23,10 @@ import gradio.utils import gradio.routes from modules import sd_hijack -from modules.deepbooru import get_deepbooru_tags from modules.paths import script_path from modules.shared import opts, cmd_opts +if cmd_opts.deepdanbooru: + from modules.deepbooru import get_deepbooru_tags import modules.shared as shared from modules.sd_samplers import samplers, samplers_for_img2img from modules.sd_hijack import model_hijack @@ -437,7 +438,10 @@ def create_toprow(is_img2img): with gr.Row(scale=1): if is_img2img: interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + if cmd_opts.deepdanbooru: + deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + else: + deepbooru = None else: interrogate = None deepbooru = None @@ -782,11 +786,12 @@ def create_ui(wrap_gradio_gpu_call): outputs=[img2img_prompt], ) - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) + if cmd_opts.deepdanbooru: + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) save.click( fn=wrap_gradio_call(save_files), -- cgit v1.2.3 From f9c5da159245bb1e7603b3c8b9e0703bcb1c2ff5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 19:05:19 +0300 Subject: add fallback for xformers_attnblock_forward --- modules/sd_hijack_optimizations.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index d23d733b..dba21192 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -211,6 +211,7 @@ def cross_attention_attnblock_forward(self, x): return h3 def xformers_attnblock_forward(self, x): + try: h_ = x h_ = self.norm(h_) q1 = self.q(h_).contiguous() @@ -218,4 +219,6 @@ def xformers_attnblock_forward(self, x): v = self.v(h_).contiguous() out = xformers.ops.memory_efficient_attention(q1, k1, v) out = self.proj_out(out) - return x+out + return x + out + except NotImplementedError: + return cross_attention_attnblock_forward(self, x) -- cgit v1.2.3 From 3061cdb7b610d4ba7f1ea695d9d6364b591e5bc7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 19:22:15 +0300 Subject: add --force-enable-xformers option and also add messages to console regarding cross attention optimizations --- modules/sd_hijack.py | 6 +++++- modules/shared.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a3e374f0..307cc67d 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,12 +22,16 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6): + + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6)): + print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: + print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): + print("Applying cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward diff --git a/modules/shared.py b/modules/shared.py index 02cb2722..8f941226 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -44,6 +44,7 @@ parser.add_argument("--scunet-models-path", type=str, help="Path to directory wi parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") +parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -- cgit v1.2.3 From 15c4278f1a18b8104e135dd82690d10cff39a2e7 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 8 Oct 2022 17:50:01 +0100 Subject: TI preprocess wording MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I had to check the code to work out what splitting was 🤷🏿 --- modules/ui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index ffd75f6a..d52d74c6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -980,9 +980,9 @@ def create_ui(wrap_gradio_gpu_call): process_dst = gr.Textbox(label='Destination directory') with gr.Row(): - process_flip = gr.Checkbox(label='Flip') - process_split = gr.Checkbox(label='Split into two') - process_caption = gr.Checkbox(label='Add caption') + process_flip = gr.Checkbox(label='Create flipped copies') + process_split = gr.Checkbox(label='Split oversized images into two') + process_caption = gr.Checkbox(label='Use CLIP caption as filename') with gr.Row(): with gr.Column(scale=3): -- cgit v1.2.3 From b458fa48fe5734a872bca83061d702609cb52940 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 8 Oct 2022 17:56:28 +0100 Subject: Update ui.py --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index d52d74c6..b09359aa 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -982,7 +982,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') process_split = gr.Checkbox(label='Split oversized images into two') - process_caption = gr.Checkbox(label='Use CLIP caption as filename') + process_caption = gr.Checkbox(label='Use BLIP caption as filename') with gr.Row(): with gr.Column(scale=3): -- cgit v1.2.3 From 1371d7608b402d6f15c200ec2f5fde4579836a05 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 14:28:22 -0400 Subject: Added ability to ignore last n layers in FrozenCLIPEmbedder --- modules/sd_hijack.py | 11 +++++++++-- modules/shared.py | 1 + 2 files changed, 10 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 307cc67d..f12a9696 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -281,8 +281,15 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) - z = outputs.last_hidden_state + + tmp = -opts.CLIP_ignore_last_layers + if (opts.CLIP_ignore_last_layers == 0): + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) + z = outputs.last_hidden_state + else: + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) + z = outputs.hidden_states[tmp] + z = self.wrapped.transformer.text_model.final_layer_norm(z) # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] diff --git a/modules/shared.py b/modules/shared.py index 8f941226..af8dc744 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -225,6 +225,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), + 'CLIP_ignore_last_layers': OptionInfo(0, "Ignore last layers of CLIP model", gr.Slider, {"minimum": 0, "maximum": 5, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.3 From e6e42f98df2c928c4f49351ad6b466387ce87d42 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 19:25:10 +0300 Subject: make --force-enable-xformers work without needing --xformers --- modules/sd_hijack_optimizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index dba21192..c4396bb9 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -10,7 +10,7 @@ from einops import rearrange from modules import shared -if shared.cmd_opts.xformers: +if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: import xformers.ops import functorch -- cgit v1.2.3 From 3b2141c5fb6a3c2b8ab4b1e759a97ead77260129 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 22:21:15 +0300 Subject: add 'Ignore last layers of CLIP model' option as a parameter to the infotext --- modules/processing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 8240ee27..515fc91a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -123,6 +123,7 @@ class Processed: self.index_of_first_image = index_of_first_image self.styles = p.styles self.job_timestamp = state.job_timestamp + self.clip_skip = opts.CLIP_ignore_last_layers self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -141,7 +142,6 @@ class Processed: self.all_subseeds = all_subseeds or [self.subseed] self.infotexts = infotexts or [info] - def js(self): obj = { "prompt": self.prompt, @@ -170,6 +170,7 @@ class Processed: "infotexts": self.infotexts, "styles": self.styles, "job_timestamp": self.job_timestamp, + "clip_skip": self.clip_skip, } return json.dumps(obj) @@ -267,6 +268,8 @@ def fix_seed(p): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): index = position_in_batch + iteration * p.batch_size + clip_skip = getattr(p, 'clip_skip', opts.CLIP_ignore_last_layers) + generation_params = { "Steps": p.steps, "Sampler": sd_samplers.samplers[p.sampler_index].name, @@ -282,6 +285,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), + "Clip skip": None if clip_skip==0 else clip_skip, } generation_params.update(p.extra_generation_params) -- cgit v1.2.3 From 610a7f4e1480c0ffeedb2a07dc27ae86bf03c3a8 Mon Sep 17 00:00:00 2001 From: Edouard Leurent Date: Sat, 8 Oct 2022 16:49:43 +0100 Subject: Break after finding the local directory of stable diffusion Otherwise, we may override it with one of the next two path (. or ..) if it is present there, and then the local paths of other modules (taming transformers, codeformers, etc.) wont be found in sd_path/../. Fix https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/1085 --- modules/paths.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/paths.py b/modules/paths.py index 606f7d66..0519caa0 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -12,6 +12,7 @@ possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), for possible_sd_path in possible_sd_paths: if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): sd_path = os.path.abspath(possible_sd_path) + break assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) -- cgit v1.2.3 From 432782163ae53e605470bcefc9a6f796c4556912 Mon Sep 17 00:00:00 2001 From: Aidan Holland Date: Sat, 8 Oct 2022 15:12:24 -0400 Subject: chore: Fix typos --- modules/interrogate.py | 4 ++-- modules/processing.py | 2 +- modules/scunet_model_arch.py | 4 ++-- modules/sd_models.py | 4 ++-- modules/sd_samplers.py | 4 ++-- modules/shared.py | 6 +++--- modules/swinir_model_arch.py | 2 +- modules/ui.py | 4 ++-- 8 files changed, 15 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/interrogate.py b/modules/interrogate.py index eed87144..635e266e 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -140,11 +140,11 @@ class InterrogateModels: res = caption - cilp_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) + clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext with torch.no_grad(), precision_scope("cuda"): - image_features = self.clip_model.encode_image(cilp_image).type(self.dtype) + image_features = self.clip_model.encode_image(clip_image).type(self.dtype) image_features /= image_features.norm(dim=-1, keepdim=True) diff --git a/modules/processing.py b/modules/processing.py index 515fc91a..31220881 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -386,7 +386,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.interrupted or state.skipped: - # if we are interruped, sample returns just noise + # if we are interrupted, sample returns just noise # use the image collected previously in sampler loop samples_ddim = shared.state.current_latent diff --git a/modules/scunet_model_arch.py b/modules/scunet_model_arch.py index 972a2639..43ca8d36 100644 --- a/modules/scunet_model_arch.py +++ b/modules/scunet_model_arch.py @@ -40,7 +40,7 @@ class WMSA(nn.Module): Returns: attn_mask: should be (1 1 w p p), """ - # supporting sqaure. + # supporting square. attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) if self.type == 'W': return attn_mask @@ -65,7 +65,7 @@ class WMSA(nn.Module): x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) h_windows = x.size(1) w_windows = x.size(2) - # sqaure validation + # square validation # assert h_windows == w_windows x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) diff --git a/modules/sd_models.py b/modules/sd_models.py index 9409d070..a09866ce 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -147,7 +147,7 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): model.first_stage_model.load_state_dict(vae_dict) model.sd_model_hash = sd_model_hash - model.sd_model_checkpint = checkpoint_file + model.sd_model_checkpoint = checkpoint_file def load_model(): @@ -175,7 +175,7 @@ def reload_model_weights(sd_model, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - if sd_model.sd_model_checkpint == checkpoint_info.filename: + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index eade0dbb..6e743f7e 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -181,7 +181,7 @@ class VanillaStableDiffusionSampler: self.initialize(p) - # existing code fails with cetain step counts, like 9 + # existing code fails with certain step counts, like 9 try: self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) except Exception: @@ -204,7 +204,7 @@ class VanillaStableDiffusionSampler: steps = steps or p.steps - # existing code fails with cetin step counts, like 9 + # existing code fails with certain step counts, like 9 try: samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta) except Exception: diff --git a/modules/shared.py b/modules/shared.py index af8dc744..2dc092d6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -141,9 +141,9 @@ class OptionInfo: self.section = None -def options_section(section_identifer, options_dict): +def options_section(section_identifier, options_dict): for k, v in options_dict.items(): - v.section = section_identifer + v.section = section_identifier return options_dict @@ -246,7 +246,7 @@ options_templates.update(options_section(('ui', "User interface"), { "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), - "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), + "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), })) diff --git a/modules/swinir_model_arch.py b/modules/swinir_model_arch.py index 461fb354..863f42db 100644 --- a/modules/swinir_model_arch.py +++ b/modules/swinir_model_arch.py @@ -166,7 +166,7 @@ class SwinTransformerBlock(nn.Module): Args: dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. + input_resolution (tuple[int]): Input resolution. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. diff --git a/modules/ui.py b/modules/ui.py index b09359aa..b51af121 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -38,7 +38,7 @@ from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') @@ -102,7 +102,7 @@ def save_files(js_data, images, index): import csv filenames = [] - #quick dictionary to class object conversion. Its neccesary due apply_filename_pattern requiring it + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it class MyObject: def __init__(self, d=None): if d is not None: -- cgit v1.2.3 From 050a6a798cec90ae2f881c2ddd3f0221e69907dc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 23:26:48 +0300 Subject: support loading .yaml config with same name as model support EMA weights in processing (????) --- modules/processing.py | 2 +- modules/sd_models.py | 30 +++++++++++++++++++++++------- 2 files changed, 24 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 31220881..4fea6d56 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -347,7 +347,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] - with torch.no_grad(): + with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(all_prompts, all_seeds, all_subseeds) diff --git a/modules/sd_models.py b/modules/sd_models.py index a09866ce..cb3982b1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ from modules.paths import models_path model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) +CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) checkpoints_list = {} try: @@ -63,14 +63,20 @@ def list_models(): if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) shared.opts.data['sd_model_checkpoint'] = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) for filename in model_list: h = model_hash(filename) title, short_model_name = modeltitle(filename, h) - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) + + basename, _ = os.path.splitext(filename) + config = basename + ".yaml" + if not os.path.exists(config): + config = shared.cmd_opts.config + + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) def get_closet_checkpoint_match(searchString): @@ -116,7 +122,10 @@ def select_checkpoint(): return checkpoint_info -def load_model_weights(model, checkpoint_file, sd_model_hash): +def load_model_weights(model, checkpoint_info): + checkpoint_file = checkpoint_info.filename + sd_model_hash = checkpoint_info.hash + print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location="cpu") @@ -148,15 +157,19 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file + model.sd_checkpoint_info = checkpoint_info def load_model(): from modules import lowvram, sd_hijack checkpoint_info = select_checkpoint() - sd_config = OmegaConf.load(shared.cmd_opts.config) + if checkpoint_info.config != shared.cmd_opts.config: + print(f"Loading config from: {shared.cmd_opts.config}") + + sd_config = OmegaConf.load(checkpoint_info.config) sd_model = instantiate_from_config(sd_config.model) - load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) @@ -178,6 +191,9 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return + if sd_model.sd_checkpoint_info.config != checkpoint_info.config: + return load_model() + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() else: @@ -185,7 +201,7 @@ def reload_model_weights(sd_model, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash) + load_model_weights(sd_model, checkpoint_info) sd_hijack.model_hijack.hijack(sd_model) -- cgit v1.2.3 From 5841990b0df04906da7321beef6f7f7902b7d57b Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 05:38:38 +0100 Subject: Update textual_inversion.py --- modules/textual_inversion/textual_inversion.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index cd9f3498..f6316020 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,6 +7,9 @@ import tqdm import html import datetime +from PIL import Image, PngImagePlugin +import base64 +from io import BytesIO from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset @@ -80,7 +83,15 @@ class EmbeddingDatabase: def process_file(path, filename): name = os.path.splitext(filename)[0] - data = torch.load(path, map_location="cpu") + data = [] + + if filename.upper().endswith('.PNG'): + embed_image = Image.open(path) + if 'sd-embedding' in embed_image.text: + embeddingData = base64.b64decode(embed_image.text['sd-embedding']) + data = torch.load(BytesIO(embeddingData), map_location="cpu") + else: + data = torch.load(path, map_location="cpu") # textual inversion embeddings if 'string_to_param' in data: @@ -156,7 +167,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -244,7 +255,15 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, image = processed.images[0] shared.state.current_image = image - image.save(last_saved_image) + + if save_image_with_stored_embedding: + info = PngImagePlugin.PngInfo() + info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read())) + image.save(last_saved_image, "PNG", pnginfo=info) + else: + image.save(last_saved_image) + + last_saved_image += f", prompt: {text}" -- cgit v1.2.3 From cd8673bd9b2e59bddefee8d307340d643695fe11 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 05:40:57 +0100 Subject: add embed embedding to ui --- modules/ui.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index b51af121..a5983204 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1001,7 +1001,8 @@ def create_ui(wrap_gradio_gpu_call): steps = gr.Number(label='Max steps', value=100000, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) - + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) + with gr.Row(): with gr.Column(scale=2): gr.HTML(value="") @@ -1063,6 +1064,7 @@ def create_ui(wrap_gradio_gpu_call): create_image_every, save_embedding_every, template_file, + save_image_with_stored_embedding, ], outputs=[ ti_output, -- cgit v1.2.3 From c77c89cc83c618472ad352cf8a28fde28c3a1377 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 10:23:31 +0300 Subject: make main model loading and model merger use the same code --- modules/extras.py | 6 +++--- modules/sd_models.py | 14 +++++++++----- 2 files changed, 12 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 1d9e64e5..ef6e6de7 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -169,9 +169,9 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int print(f"Loading {secondary_model_info.filename}...") secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') - - theta_0 = primary_model['state_dict'] - theta_1 = secondary_model['state_dict'] + + theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model) + theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model) theta_funcs = { "Weighted Sum": weighted_sum, diff --git a/modules/sd_models.py b/modules/sd_models.py index cb3982b1..18fb8c2e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -122,6 +122,13 @@ def select_checkpoint(): return checkpoint_info +def get_state_dict_from_checkpoint(pl_sd): + if "state_dict" in pl_sd: + return pl_sd["state_dict"] + + return pl_sd + + def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash @@ -131,11 +138,8 @@ def load_model_weights(model, checkpoint_info): pl_sd = torch.load(checkpoint_file, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") - - if "state_dict" in pl_sd: - sd = pl_sd["state_dict"] - else: - sd = pl_sd + + sd = get_state_dict_from_checkpoint(pl_sd) model.load_state_dict(sd, strict=False) -- cgit v1.2.3 From 4e569fd888f8e3c5632a072d51abbb6e4d17abd6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 10:31:47 +0300 Subject: fixed incorrect message about loading config; thanks anon! --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 18fb8c2e..2101b18d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -169,7 +169,7 @@ def load_model(): checkpoint_info = select_checkpoint() if checkpoint_info.config != shared.cmd_opts.config: - print(f"Loading config from: {shared.cmd_opts.config}") + print(f"Loading config from: {checkpoint_info.config}") sd_config = OmegaConf.load(checkpoint_info.config) sd_model = instantiate_from_config(sd_config.model) -- cgit v1.2.3 From 5ab7e88d9b0bb0125af9f7237242a00a93360ce5 Mon Sep 17 00:00:00 2001 From: aoirusann <82883326+aoirusann@users.noreply.github.com> Date: Sat, 8 Oct 2022 13:09:29 +0800 Subject: Add `Download` & `Download as zip` --- modules/ui.py | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index b51af121..fe7f10a7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -98,9 +98,10 @@ def send_gradio_gallery_to_image(x): return image_from_url_text(x[0]) -def save_files(js_data, images, index): +def save_files(js_data, images, do_make_zip, index): import csv filenames = [] + fullfns = [] #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it class MyObject: @@ -141,10 +142,22 @@ def save_files(js_data, images, index): filename = os.path.relpath(fullfn, path) filenames.append(filename) + fullfns.append(fullfn) writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - return '', '', plaintext_to_html(f"Saved: {filenames[0]}") + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return fullfns, '', '', plaintext_to_html(f"Saved: {filenames[0]}") def wrap_gradio_call(func, extra_outputs=None): @@ -521,6 +534,12 @@ def create_ui(wrap_gradio_gpu_call): button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id) + with gr.Row(): + do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) + + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False) + with gr.Group(): html_info = gr.HTML() generation_info = gr.Textbox(visible=False) @@ -570,13 +589,15 @@ def create_ui(wrap_gradio_gpu_call): save.click( fn=wrap_gradio_call(save_files), - _js="(x, y, z) => [x, y, selected_gallery_index()]", + _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", inputs=[ generation_info, txt2img_gallery, + do_make_zip, html_info, ], outputs=[ + download_files, html_info, html_info, html_info, @@ -701,6 +722,12 @@ def create_ui(wrap_gradio_gpu_call): button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id) + with gr.Row(): + do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) + + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False) + with gr.Group(): html_info = gr.HTML() generation_info = gr.Textbox(visible=False) @@ -776,13 +803,15 @@ def create_ui(wrap_gradio_gpu_call): save.click( fn=wrap_gradio_call(save_files), - _js="(x, y, z) => [x, y, selected_gallery_index()]", + _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", inputs=[ generation_info, img2img_gallery, - html_info + do_make_zip, + html_info, ], outputs=[ + download_files, html_info, html_info, html_info, -- cgit v1.2.3 From 14192c5b207b16b1ec7a4c9c4ea538d1a6811a4d Mon Sep 17 00:00:00 2001 From: aoirusann Date: Sun, 9 Oct 2022 13:01:10 +0800 Subject: Support `Download` for txt files. --- modules/images.py | 39 +++++++++++++++++++++++++++++++++++++-- modules/ui.py | 5 ++++- 2 files changed, 41 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 29c5ee24..c0a90676 100644 --- a/modules/images.py +++ b/modules/images.py @@ -349,6 +349,38 @@ def get_next_sequence_number(path, basename): def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None): + '''Save an image. + + Args: + image (`PIL.Image`): + The image to be saved. + path (`str`): + The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory. + basename (`str`): + The base filename which will be applied to `filename pattern`. + seed, prompt, short_filename, + extension (`str`): + Image file extension, default is `png`. + pngsectionname (`str`): + Specify the name of the section which `info` will be saved in. + info (`str` or `PngImagePlugin.iTXt`): + PNG info chunks. + existing_info (`dict`): + Additional PNG info. `existing_info == {pngsectionname: info, ...}` + no_prompt: + TODO I don't know its meaning. + p (`StableDiffusionProcessing`) + forced_filename (`str`): + If specified, `basename` and filename pattern will be ignored. + save_to_dirs (bool): + If true, the image will be saved into a subdirectory of `path`. + + Returns: (fullfn, txt_fullfn) + fullfn (`str`): + The full path of the saved imaged. + txt_fullfn (`str` or None): + If a text file is saved for this image, this will be its full path. Otherwise None. + ''' if short_filename or prompt is None or seed is None: file_decoration = "" elif opts.save_to_dirs: @@ -424,7 +456,10 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg") if opts.save_txt and info is not None: - with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file: + txt_fullfn = f"{fullfn_without_extension}.txt" + with open(txt_fullfn, "w", encoding="utf8") as file: file.write(info + "\n") + else: + txt_fullfn = None - return fullfn + return fullfn, txt_fullfn diff --git a/modules/ui.py b/modules/ui.py index fe7f10a7..debd8873 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -138,11 +138,14 @@ def save_files(js_data, images, do_make_zip, index): is_grid = image_index < p.index_of_first_image i = 0 if is_grid else (image_index - p.index_of_first_image) - fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) filename = os.path.relpath(fullfn, path) filenames.append(filename) fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) -- cgit v1.2.3 From 122d42687b97ec4df4c2a8c335d2de385cd1f1a1 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 22:37:35 -0400 Subject: Fix VRAM Issue by only loading in hypernetwork when selected in settings --- modules/hypernetwork.py | 23 +++++++++++++++-------- modules/sd_hijack_optimizations.py | 6 +++--- modules/shared.py | 7 ++----- 3 files changed, 20 insertions(+), 16 deletions(-) (limited to 'modules') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 7f062242..19f1c227 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -40,18 +40,25 @@ class Hypernetwork: self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) -def load_hypernetworks(path): +def list_hypernetworks(path): res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res + + +def load_hypernetwork(filename): + print(f"Loading hypernetwork {filename}") + path = shared.hypernetworks.get(filename, None) + if (path is not None): try: - hn = Hypernetwork(filename) - res[hn.name] = hn + shared.loaded_hypernetwork = Hypernetwork(path) except Exception: - print(f"Error loading hypernetwork {filename}", file=sys.stderr) + print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - - return res + else: + shared.loaded_hypernetwork = None def attention_CrossAttention_forward(self, x, context=None, mask=None): @@ -60,7 +67,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index c4396bb9..634fb4b2 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -28,7 +28,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: @@ -68,7 +68,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: @@ -132,7 +132,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: k_in = self.to_k(hypernetwork_layers[0](context)) diff --git a/modules/shared.py b/modules/shared.py index b2c76a32..9dce6cb7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -79,11 +79,8 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram xformers_available = False config_filename = cmd_opts.ui_settings_file -hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks')) - - -def selected_hypernetwork(): - return hypernetworks.get(opts.sd_hypernetwork, None) +hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks')) +loaded_hypernetwork = None class State: -- cgit v1.2.3 From 03e570886f430f39020e504aba057a95f2e62484 Mon Sep 17 00:00:00 2001 From: frostydad <64224601+Cyberes@users.noreply.github.com> Date: Sat, 8 Oct 2022 18:13:13 -0600 Subject: Fix incorrect sampler name in output --- modules/processing.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 4fea6d56..6b8664a0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,3 +1,4 @@ + import json import math import os @@ -46,6 +47,12 @@ def apply_color_correction(correction, image): return image +def get_correct_sampler(p): + if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img): + return sd_samplers.samplers + elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): + return sd_samplers.samplers_for_img2img + class StableDiffusionProcessing: def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): self.sd_model = sd_model @@ -272,7 +279,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params = { "Steps": p.steps, - "Sampler": sd_samplers.samplers[p.sampler_index].name, + "Sampler": get_correct_sampler(p)[p.sampler_index].name, "CFG scale": p.cfg_scale, "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), -- cgit v1.2.3 From ef93acdc731b7a2b3c13651b6de1bce58af989d4 Mon Sep 17 00:00:00 2001 From: frostydad <64224601+Cyberes@users.noreply.github.com> Date: Sat, 8 Oct 2022 18:15:35 -0600 Subject: remove line break --- modules/processing.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 6b8664a0..7fa1144e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,4 +1,3 @@ - import json import math import os -- cgit v1.2.3 From 1ffeb42d38d9276dc28918189d32f60d593a162c Mon Sep 17 00:00:00 2001 From: Nicolas Noullet Date: Sun, 9 Oct 2022 00:18:45 +0200 Subject: Fix typo --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 9dce6cb7..dffa0094 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -238,7 +238,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), - "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), + "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), -- cgit v1.2.3 From e2930f9821c197da94e208b5ae73711002844efc Mon Sep 17 00:00:00 2001 From: Tony Beeman Date: Fri, 7 Oct 2022 17:46:39 -0700 Subject: Fix for Prompts_from_file showing extra textbox. --- modules/scripts.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 45230f9a..d8f87927 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,4 +1,5 @@ import os +from pydoc import visiblename import sys import traceback @@ -31,6 +32,15 @@ class Script: def show(self, is_img2img): return True + + # Called when the ui for this script has been shown. + # Useful for hiding some controls, since the scripts module sets visibility to + # everything to true. The parameters will be the parameters returned by the ui method + # The return value should be gradio updates, similar to what you would return + # from a Gradio event handler. + def on_show(self, *args): + return [ui.gr_show(True)] * len(args) + # This is where the additional processing is implemented. The parameters include # self, the model object "p" (a StableDiffusionProcessing class, see # processing.py), and the parameters returned by the ui method. @@ -125,20 +135,32 @@ class ScriptRunner: inputs += controls script.args_to = len(inputs) - def select_script(script_index): + def select_script(*args): + script_index = args[0] + on_show_updates = [] if 0 < script_index <= len(self.scripts): script = self.scripts[script_index-1] args_from = script.args_from args_to = script.args_to + script_args = args[args_from:args_to] + on_show_updates = wrap_call(script.on_show, script.filename, "on_show", *script_args) else: args_from = 0 args_to = 0 - return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] + ret = [ ui.gr_show(True)] # always show the dropdown + for i in range(1, len(inputs)): + if (args_from <= i < args_to): + ret.append( on_show_updates[i - args_from] ) + else: + ret.append(ui.gr_show(False)) + return ret + + # return [ui.gr_show(True if (i == 0) else on_show_updates[i - args_from] if args_from <= i < args_to else False) for i in range(len(inputs))] dropdown.change( fn=select_script, - inputs=[dropdown], + inputs=inputs, outputs=inputs ) @@ -198,4 +220,4 @@ def reload_scripts(basedir): load_scripts(basedir) scripts_txt2img = ScriptRunner() - scripts_img2img = ScriptRunner() + scripts_img2img = ScriptRunner() \ No newline at end of file -- cgit v1.2.3 From 86cb16886f8f48169cee4658ad0c5e5443beed2a Mon Sep 17 00:00:00 2001 From: Tony Beeman Date: Fri, 7 Oct 2022 23:51:50 -0700 Subject: Pull Request Code Review Fixes --- modules/scripts.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index d8f87927..8dfd4de9 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,5 +1,4 @@ import os -from pydoc import visiblename import sys import traceback -- cgit v1.2.3 From cbf6dad02d04d98e5a2d5e870777ab99b5796b2d Mon Sep 17 00:00:00 2001 From: Tony Beeman Date: Sat, 8 Oct 2022 10:40:30 -0700 Subject: Handle case where on_show returns the wrong number of arguments --- modules/scripts.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 8dfd4de9..7d89979d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -143,6 +143,8 @@ class ScriptRunner: args_to = script.args_to script_args = args[args_from:args_to] on_show_updates = wrap_call(script.on_show, script.filename, "on_show", *script_args) + if (len(on_show_updates) != (args_to - args_from)): + print("Error in custom script (" + script.filename + "): on_show() method should return the same number of arguments as ui().", file=sys.stderr) else: args_from = 0 args_to = 0 @@ -150,13 +152,14 @@ class ScriptRunner: ret = [ ui.gr_show(True)] # always show the dropdown for i in range(1, len(inputs)): if (args_from <= i < args_to): - ret.append( on_show_updates[i - args_from] ) + if (i - args_from) < len(on_show_updates): + ret.append( on_show_updates[i - args_from] ) + else: + ret.append(ui.gr_show(True)) else: ret.append(ui.gr_show(False)) return ret - # return [ui.gr_show(True if (i == 0) else on_show_updates[i - args_from] if args_from <= i < args_to else False) for i in range(len(inputs))] - dropdown.change( fn=select_script, inputs=inputs, -- cgit v1.2.3 From ab4fe4f44c3d2675a351269fe2ff1ddeac557aa6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 11:59:41 +0300 Subject: hide filenames for save button by default --- modules/ui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 8071b1cb..e1ab2665 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -162,7 +162,7 @@ def save_files(js_data, images, do_make_zip, index): zip_file.writestr(filenames[i], f.read()) fullfns.insert(0, zip_filepath) - return fullfns, '', '', plaintext_to_html(f"Saved: {filenames[0]}") + return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") def wrap_gradio_call(func, extra_outputs=None): @@ -553,7 +553,7 @@ def create_ui(wrap_gradio_gpu_call): do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False) + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) with gr.Group(): html_info = gr.HTML() @@ -741,7 +741,7 @@ def create_ui(wrap_gradio_gpu_call): do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False) + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) with gr.Group(): html_info = gr.HTML() -- cgit v1.2.3 From 0241d811d23427b99f6b1eda1540bdf8d87963d5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 12:04:44 +0300 Subject: Revert "Fix for Prompts_from_file showing extra textbox." This reverts commit e2930f9821c197da94e208b5ae73711002844efc. --- modules/scripts.py | 32 ++++---------------------------- 1 file changed, 4 insertions(+), 28 deletions(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 7d89979d..45230f9a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -31,15 +31,6 @@ class Script: def show(self, is_img2img): return True - - # Called when the ui for this script has been shown. - # Useful for hiding some controls, since the scripts module sets visibility to - # everything to true. The parameters will be the parameters returned by the ui method - # The return value should be gradio updates, similar to what you would return - # from a Gradio event handler. - def on_show(self, *args): - return [ui.gr_show(True)] * len(args) - # This is where the additional processing is implemented. The parameters include # self, the model object "p" (a StableDiffusionProcessing class, see # processing.py), and the parameters returned by the ui method. @@ -134,35 +125,20 @@ class ScriptRunner: inputs += controls script.args_to = len(inputs) - def select_script(*args): - script_index = args[0] - on_show_updates = [] + def select_script(script_index): if 0 < script_index <= len(self.scripts): script = self.scripts[script_index-1] args_from = script.args_from args_to = script.args_to - script_args = args[args_from:args_to] - on_show_updates = wrap_call(script.on_show, script.filename, "on_show", *script_args) - if (len(on_show_updates) != (args_to - args_from)): - print("Error in custom script (" + script.filename + "): on_show() method should return the same number of arguments as ui().", file=sys.stderr) else: args_from = 0 args_to = 0 - ret = [ ui.gr_show(True)] # always show the dropdown - for i in range(1, len(inputs)): - if (args_from <= i < args_to): - if (i - args_from) < len(on_show_updates): - ret.append( on_show_updates[i - args_from] ) - else: - ret.append(ui.gr_show(True)) - else: - ret.append(ui.gr_show(False)) - return ret + return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] dropdown.change( fn=select_script, - inputs=inputs, + inputs=[dropdown], outputs=inputs ) @@ -222,4 +198,4 @@ def reload_scripts(basedir): load_scripts(basedir) scripts_txt2img = ScriptRunner() - scripts_img2img = ScriptRunner() \ No newline at end of file + scripts_img2img = ScriptRunner() -- cgit v1.2.3 From 6f6798ddabe10d320fe8ea05edf0fdcef0c51a8e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 12:33:37 +0300 Subject: prevent a possible code execution error (thanks, RyotaK) --- modules/ui.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index e1ab2665..dad509f3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1153,6 +1153,15 @@ def create_ui(wrap_gradio_gpu_call): component_dict = {} def open_folder(f): + if not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + if not shared.cmd_opts.hide_ui_dir_config: path = os.path.normpath(f) if platform.system() == "Windows": -- cgit v1.2.3 From 0609ce06c0778536cb368ac3867292f87c6d9fc7 Mon Sep 17 00:00:00 2001 From: Milly Date: Fri, 7 Oct 2022 03:36:08 +0900 Subject: Removed duplicate definition model_path --- modules/bsrgan_model.py | 2 -- modules/esrgan_model.py | 2 -- modules/ldsr_model.py | 2 -- modules/realesrgan_model.py | 2 -- modules/scunet_model.py | 2 -- modules/swinir_model.py | 2 -- modules/upscaler.py | 7 ++++--- 7 files changed, 4 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py index 3bd80791..737e1a76 100644 --- a/modules/bsrgan_model.py +++ b/modules/bsrgan_model.py @@ -10,13 +10,11 @@ from basicsr.utils.download_util import load_file_from_url import modules.upscaler from modules import devices, modelloader from modules.bsrgan_model_arch import RRDBNet -from modules.paths import models_path class UpscalerBSRGAN(modules.upscaler.Upscaler): def __init__(self, dirname): self.name = "BSRGAN" - self.model_path = os.path.join(models_path, self.name) self.model_name = "BSRGAN 4x" self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth" self.user_path = dirname diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 28548124..3970e6e4 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -7,7 +7,6 @@ from basicsr.utils.download_util import load_file_from_url import modules.esrgam_model_arch as arch from modules import shared, modelloader, images, devices -from modules.paths import models_path from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts @@ -76,7 +75,6 @@ class UpscalerESRGAN(Upscaler): self.model_name = "ESRGAN_4x" self.scalers = [] self.user_path = dirname - self.model_path = os.path.join(models_path, self.name) super().__init__() model_paths = self.find_models(ext_filter=[".pt", ".pth"]) scalers = [] diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py index 1c1070fc..8c4db44a 100644 --- a/modules/ldsr_model.py +++ b/modules/ldsr_model.py @@ -7,13 +7,11 @@ from basicsr.utils.download_util import load_file_from_url from modules.upscaler import Upscaler, UpscalerData from modules.ldsr_model_arch import LDSR from modules import shared -from modules.paths import models_path class UpscalerLDSR(Upscaler): def __init__(self, user_path): self.name = "LDSR" - self.model_path = os.path.join(models_path, self.name) self.user_path = user_path self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index dc0123e0..3ac0b97a 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -8,14 +8,12 @@ from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer from modules.upscaler import Upscaler, UpscalerData -from modules.paths import models_path from modules.shared import cmd_opts, opts class UpscalerRealESRGAN(Upscaler): def __init__(self, path): self.name = "RealESRGAN" - self.model_path = os.path.join(models_path, self.name) self.user_path = path super().__init__() try: diff --git a/modules/scunet_model.py b/modules/scunet_model.py index fb64b740..36a996bf 100644 --- a/modules/scunet_model.py +++ b/modules/scunet_model.py @@ -9,14 +9,12 @@ from basicsr.utils.download_util import load_file_from_url import modules.upscaler from modules import devices, modelloader -from modules.paths import models_path from modules.scunet_model_arch import SCUNet as net class UpscalerScuNET(modules.upscaler.Upscaler): def __init__(self, dirname): self.name = "ScuNET" - self.model_path = os.path.join(models_path, self.name) self.model_name = "ScuNET GAN" self.model_name2 = "ScuNET PSNR" self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth" diff --git a/modules/swinir_model.py b/modules/swinir_model.py index 9bd454c6..fbd11f84 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -8,7 +8,6 @@ from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm from modules import modelloader -from modules.paths import models_path from modules.shared import cmd_opts, opts, device from modules.swinir_model_arch import SwinIR as net from modules.upscaler import Upscaler, UpscalerData @@ -25,7 +24,6 @@ class UpscalerSwinIR(Upscaler): "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ "-L_x4_GAN.pth " self.model_name = "SwinIR 4x" - self.model_path = os.path.join(models_path, self.name) self.user_path = dirname super().__init__() scalers = [] diff --git a/modules/upscaler.py b/modules/upscaler.py index d9d7c5e2..34672be7 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -36,10 +36,11 @@ class Upscaler: self.half = not modules.shared.cmd_opts.no_half self.pre_pad = 0 self.mod_scale = None - if self.name is not None and create_dirs: + + if self.model_path is not None and self.name: self.model_path = os.path.join(models_path, self.name) - if not os.path.exists(self.model_path): - os.makedirs(self.model_path) + if self.model_path and create_dirs: + os.makedirs(self.model_path, exist_ok=True) try: import cv2 -- cgit v1.2.3 From bd833409ac7b8337040d521f6b65ced51e1b2ea8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 13:10:15 +0300 Subject: additional changes for saving pnginfo for #1803 --- modules/extras.py | 4 ++++ modules/processing.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index ef6e6de7..39dd3806 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -98,6 +98,10 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=image_name if opts.use_original_name_batch else None) + if opts.enable_pnginfo: + image.info = existing_pnginfo + image.info["extras"] = info + outputs.append(image) devices.torch_gc() diff --git a/modules/processing.py b/modules/processing.py index 7fa1144e..2c991317 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -451,7 +451,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: text = infotext(n, i) infotexts.append(text) - image.info["parameters"] = text + if opts.enable_pnginfo: + image.info["parameters"] = text output_images.append(image) del x_samples_ddim @@ -470,7 +471,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if opts.return_grid: text = infotext() infotexts.insert(0, text) - grid.info["parameters"] = text + if opts.enable_pnginfo: + grid.info["parameters"] = text output_images.insert(0, grid) index_of_first_image = 1 -- cgit v1.2.3 From f4578b343ded3b8ccd1879ea0c0b3cdadfcc3a5f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 13:23:30 +0300 Subject: fix model switching not working properly if there is a different yaml config --- modules/sd_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 2101b18d..d0c74dd8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -196,7 +196,8 @@ def reload_model_weights(sd_model, info=None): return if sd_model.sd_checkpoint_info.config != checkpoint_info.config: - return load_model() + shared.sd_model = load_model() + return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() -- cgit v1.2.3 From 77a719648db515f10136e8b8483d5b16bda2eaeb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 13:48:04 +0300 Subject: fix logic error in #1832 --- modules/upscaler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/upscaler.py b/modules/upscaler.py index 34672be7..6ab2fb40 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -37,7 +37,7 @@ class Upscaler: self.pre_pad = 0 self.mod_scale = None - if self.model_path is not None and self.name: + if self.model_path is None and self.name: self.model_path = os.path.join(models_path, self.name) if self.model_path and create_dirs: os.makedirs(self.model_path, exist_ok=True) -- cgit v1.2.3 From ad4de819c43997f2666b5bad95301f5c37f9018e Mon Sep 17 00:00:00 2001 From: victorca25 Date: Sun, 9 Oct 2022 13:02:12 +0200 Subject: update ESRGAN architecture and model to support all ESRGAN models in the DB, BSRGAN and real-ESRGAN models --- modules/bsrgan_model.py | 76 ------- modules/bsrgan_model_arch.py | 102 ---------- modules/esrgam_model_arch.py | 80 -------- modules/esrgan_model.py | 190 ++++++++++++------ modules/esrgan_model_arch.py | 463 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 591 insertions(+), 320 deletions(-) delete mode 100644 modules/bsrgan_model.py delete mode 100644 modules/bsrgan_model_arch.py delete mode 100644 modules/esrgam_model_arch.py create mode 100644 modules/esrgan_model_arch.py (limited to 'modules') diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py deleted file mode 100644 index 737e1a76..00000000 --- a/modules/bsrgan_model.py +++ /dev/null @@ -1,76 +0,0 @@ -import os.path -import sys -import traceback - -import PIL.Image -import numpy as np -import torch -from basicsr.utils.download_util import load_file_from_url - -import modules.upscaler -from modules import devices, modelloader -from modules.bsrgan_model_arch import RRDBNet - - -class UpscalerBSRGAN(modules.upscaler.Upscaler): - def __init__(self, dirname): - self.name = "BSRGAN" - self.model_name = "BSRGAN 4x" - self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth" - self.user_path = dirname - super().__init__() - model_paths = self.find_models(ext_filter=[".pt", ".pth"]) - scalers = [] - if len(model_paths) == 0: - scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4) - scalers.append(scaler_data) - for file in model_paths: - if "http" in file: - name = self.model_name - else: - name = modelloader.friendly_name(file) - try: - scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) - scalers.append(scaler_data) - except Exception: - print(f"Error loading BSRGAN model: {file}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - self.scalers = scalers - - def do_upscale(self, img: PIL.Image, selected_file): - torch.cuda.empty_cache() - model = self.load_model(selected_file) - if model is None: - return img - model.to(devices.device_bsrgan) - torch.cuda.empty_cache() - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_bsrgan) - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - torch.cuda.empty_cache() - return PIL.Image.fromarray(output, 'RGB') - - def load_model(self, path: str): - if "http" in path: - filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, - progress=True) - else: - filename = path - if not os.path.exists(filename) or filename is None: - print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr) - return None - model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network - model.load_state_dict(torch.load(filename), strict=True) - model.eval() - for k, v in model.named_parameters(): - v.requires_grad = False - return model - diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py deleted file mode 100644 index cb4d1c13..00000000 --- a/modules/bsrgan_model_arch.py +++ /dev/null @@ -1,102 +0,0 @@ -import functools -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.nn.init as init - - -def initialize_weights(net_l, scale=1): - if not isinstance(net_l, list): - net_l = [net_l] - for net in net_l: - for m in net.modules(): - if isinstance(m, nn.Conv2d): - init.kaiming_normal_(m.weight, a=0, mode='fan_in') - m.weight.data *= scale # for residual block - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - init.kaiming_normal_(m.weight, a=0, mode='fan_in') - m.weight.data *= scale - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - init.constant_(m.weight, 1) - init.constant_(m.bias.data, 0.0) - - -def make_layer(block, n_layers): - layers = [] - for _ in range(n_layers): - layers.append(block()) - return nn.Sequential(*layers) - - -class ResidualDenseBlock_5C(nn.Module): - def __init__(self, nf=64, gc=32, bias=True): - super(ResidualDenseBlock_5C, self).__init__() - # gc: growth channel, i.e. intermediate channels - self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) - self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) - self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) - self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) - self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - # initialization - initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) - - def forward(self, x): - x1 = self.lrelu(self.conv1(x)) - x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) - x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) - x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5 * 0.2 + x - - -class RRDB(nn.Module): - '''Residual in Residual Dense Block''' - - def __init__(self, nf, gc=32): - super(RRDB, self).__init__() - self.RDB1 = ResidualDenseBlock_5C(nf, gc) - self.RDB2 = ResidualDenseBlock_5C(nf, gc) - self.RDB3 = ResidualDenseBlock_5C(nf, gc) - - def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - return out * 0.2 + x - - -class RRDBNet(nn.Module): - def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4): - super(RRDBNet, self).__init__() - RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) - self.sf = sf - - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.RRDB_trunk = make_layer(RRDB_block_f, nb) - self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - #### upsampling - self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - if self.sf==4: - self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - def forward(self, x): - fea = self.conv_first(x) - trunk = self.trunk_conv(self.RRDB_trunk(fea)) - fea = fea + trunk - - fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) - if self.sf==4: - fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) - out = self.conv_last(self.lrelu(self.HRconv(fea))) - - return out \ No newline at end of file diff --git a/modules/esrgam_model_arch.py b/modules/esrgam_model_arch.py deleted file mode 100644 index e413d36e..00000000 --- a/modules/esrgam_model_arch.py +++ /dev/null @@ -1,80 +0,0 @@ -# this file is taken from https://github.com/xinntao/ESRGAN - -import functools -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def make_layer(block, n_layers): - layers = [] - for _ in range(n_layers): - layers.append(block()) - return nn.Sequential(*layers) - - -class ResidualDenseBlock_5C(nn.Module): - def __init__(self, nf=64, gc=32, bias=True): - super(ResidualDenseBlock_5C, self).__init__() - # gc: growth channel, i.e. intermediate channels - self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) - self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) - self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) - self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) - self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - # initialization - # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) - - def forward(self, x): - x1 = self.lrelu(self.conv1(x)) - x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) - x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) - x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5 * 0.2 + x - - -class RRDB(nn.Module): - '''Residual in Residual Dense Block''' - - def __init__(self, nf, gc=32): - super(RRDB, self).__init__() - self.RDB1 = ResidualDenseBlock_5C(nf, gc) - self.RDB2 = ResidualDenseBlock_5C(nf, gc) - self.RDB3 = ResidualDenseBlock_5C(nf, gc) - - def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - return out * 0.2 + x - - -class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32): - super(RRDBNet, self).__init__() - RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) - - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.RRDB_trunk = make_layer(RRDB_block_f, nb) - self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - #### upsampling - self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - def forward(self, x): - fea = self.conv_first(x) - trunk = self.trunk_conv(self.RRDB_trunk(fea)) - fea = fea + trunk - - fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) - fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) - out = self.conv_last(self.lrelu(self.HRconv(fea))) - - return out diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 3970e6e4..a49e2258 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -5,68 +5,115 @@ import torch from PIL import Image from basicsr.utils.download_util import load_file_from_url -import modules.esrgam_model_arch as arch +import modules.esrgan_model_arch as arch from modules import shared, modelloader, images, devices from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts -def fix_model_layers(crt_model, pretrained_net): - # this code is adapted from https://github.com/xinntao/ESRGAN - if 'conv_first.weight' in pretrained_net: - return pretrained_net - if 'model.0.weight' not in pretrained_net: - is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"] - if is_realesrgan: - raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.") - else: - raise Exception("The file is not a ESRGAN model.") +def mod2normal(state_dict): + # this code is copied from https://github.com/victorca25/iNNfer + if 'conv_first.weight' in state_dict: + crt_net = {} + items = [] + for k, v in state_dict.items(): + items.append(k) + + crt_net['model.0.weight'] = state_dict['conv_first.weight'] + crt_net['model.0.bias'] = state_dict['conv_first.bias'] + + for k in items.copy(): + if 'RDB' in k: + ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') + if '.weight' in k: + ori_k = ori_k.replace('.weight', '.0.weight') + elif '.bias' in k: + ori_k = ori_k.replace('.bias', '.0.bias') + crt_net[ori_k] = state_dict[k] + items.remove(k) + + crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight'] + crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias'] + crt_net['model.3.weight'] = state_dict['upconv1.weight'] + crt_net['model.3.bias'] = state_dict['upconv1.bias'] + crt_net['model.6.weight'] = state_dict['upconv2.weight'] + crt_net['model.6.bias'] = state_dict['upconv2.bias'] + crt_net['model.8.weight'] = state_dict['HRconv.weight'] + crt_net['model.8.bias'] = state_dict['HRconv.bias'] + crt_net['model.10.weight'] = state_dict['conv_last.weight'] + crt_net['model.10.bias'] = state_dict['conv_last.bias'] + state_dict = crt_net + return state_dict + + +def resrgan2normal(state_dict, nb=23): + # this code is copied from https://github.com/victorca25/iNNfer + if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: + crt_net = {} + items = [] + for k, v in state_dict.items(): + items.append(k) + + crt_net['model.0.weight'] = state_dict['conv_first.weight'] + crt_net['model.0.bias'] = state_dict['conv_first.bias'] + + for k in items.copy(): + if "rdb" in k: + ori_k = k.replace('body.', 'model.1.sub.') + ori_k = ori_k.replace('.rdb', '.RDB') + if '.weight' in k: + ori_k = ori_k.replace('.weight', '.0.weight') + elif '.bias' in k: + ori_k = ori_k.replace('.bias', '.0.bias') + crt_net[ori_k] = state_dict[k] + items.remove(k) + + crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight'] + crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias'] + crt_net['model.3.weight'] = state_dict['conv_up1.weight'] + crt_net['model.3.bias'] = state_dict['conv_up1.bias'] + crt_net['model.6.weight'] = state_dict['conv_up2.weight'] + crt_net['model.6.bias'] = state_dict['conv_up2.bias'] + crt_net['model.8.weight'] = state_dict['conv_hr.weight'] + crt_net['model.8.bias'] = state_dict['conv_hr.bias'] + crt_net['model.10.weight'] = state_dict['conv_last.weight'] + crt_net['model.10.bias'] = state_dict['conv_last.bias'] + state_dict = crt_net + return state_dict + + +def infer_params(state_dict): + # this code is copied from https://github.com/victorca25/iNNfer + scale2x = 0 + scalemin = 6 + n_uplayer = 0 + plus = False + + for block in list(state_dict): + parts = block.split(".") + n_parts = len(parts) + if n_parts == 5 and parts[2] == "sub": + nb = int(parts[3]) + elif n_parts == 3: + part_num = int(parts[1]) + if (part_num > scalemin + and parts[0] == "model" + and parts[2] == "weight"): + scale2x += 1 + if part_num > n_uplayer: + n_uplayer = part_num + out_nc = state_dict[block].shape[0] + if not plus and "conv1x1" in block: + plus = True + + nf = state_dict["model.0.weight"].shape[0] + in_nc = state_dict["model.0.weight"].shape[1] + out_nc = out_nc + scale = 2 ** scale2x + + return in_nc, out_nc, nf, nb, plus, scale - crt_net = crt_model.state_dict() - load_net_clean = {} - for k, v in pretrained_net.items(): - if k.startswith('module.'): - load_net_clean[k[7:]] = v - else: - load_net_clean[k] = v - pretrained_net = load_net_clean - - tbd = [] - for k, v in crt_net.items(): - tbd.append(k) - - # directly copy - for k, v in crt_net.items(): - if k in pretrained_net and pretrained_net[k].size() == v.size(): - crt_net[k] = pretrained_net[k] - tbd.remove(k) - - crt_net['conv_first.weight'] = pretrained_net['model.0.weight'] - crt_net['conv_first.bias'] = pretrained_net['model.0.bias'] - - for k in tbd.copy(): - if 'RDB' in k: - ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[k] = pretrained_net[ori_k] - tbd.remove(k) - - crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight'] - crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias'] - crt_net['upconv1.weight'] = pretrained_net['model.3.weight'] - crt_net['upconv1.bias'] = pretrained_net['model.3.bias'] - crt_net['upconv2.weight'] = pretrained_net['model.6.weight'] - crt_net['upconv2.bias'] = pretrained_net['model.6.bias'] - crt_net['HRconv.weight'] = pretrained_net['model.8.weight'] - crt_net['HRconv.bias'] = pretrained_net['model.8.bias'] - crt_net['conv_last.weight'] = pretrained_net['model.10.weight'] - crt_net['conv_last.bias'] = pretrained_net['model.10.bias'] - - return crt_net class UpscalerESRGAN(Upscaler): def __init__(self, dirname): @@ -109,20 +156,39 @@ class UpscalerESRGAN(Upscaler): print("Unable to load %s from %s" % (self.model_path, filename)) return None - pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) - crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) + state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) + + if "params_ema" in state_dict: + state_dict = state_dict["params_ema"] + elif "params" in state_dict: + state_dict = state_dict["params"] + num_conv = 16 if "realesr-animevideov3" in filename else 32 + model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu') + model.load_state_dict(state_dict) + model.eval() + return model + + if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict: + nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23 + state_dict = resrgan2normal(state_dict, nb) + elif "conv_first.weight" in state_dict: + state_dict = mod2normal(state_dict) + elif "model.0.weight" not in state_dict: + raise Exception("The file is not a recognized ESRGAN model.") + + in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict) - pretrained_net = fix_model_layers(crt_model, pretrained_net) - crt_model.load_state_dict(pretrained_net) - crt_model.eval() + model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus) + model.load_state_dict(state_dict) + model.eval() - return crt_model + return model def upscale_without_tiling(model, img): img = np.array(img) img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 + img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() img = img.unsqueeze(0).to(devices.device_esrgan) with torch.no_grad(): diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py new file mode 100644 index 00000000..bc9ceb2a --- /dev/null +++ b/modules/esrgan_model_arch.py @@ -0,0 +1,463 @@ +# this file is adapted from https://github.com/victorca25/iNNfer + +import math +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F + + +#################### +# RRDBNet Generator +#################### + +class RRDBNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, + act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', + finalact=None, gaussian_noise=False, plus=False): + super(RRDBNet, self).__init__() + n_upscale = int(math.log(upscale, 2)) + if upscale == 3: + n_upscale = 1 + + self.resrgan_scale = 0 + if in_nc % 16 == 0: + self.resrgan_scale = 1 + elif in_nc != 4 and in_nc % 4 == 0: + self.resrgan_scale = 2 + + fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) + rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', + norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, + gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)] + LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype) + + if upsample_mode == 'upconv': + upsample_block = upconv_block + elif upsample_mode == 'pixelshuffle': + upsample_block = pixelshuffle_block + else: + raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) + if upscale == 3: + upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype) + else: + upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)] + HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype) + HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) + + outact = act(finalact) if finalact else None + + self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)), + *upsampler, HR_conv0, HR_conv1, outact) + + def forward(self, x, outm=None): + if self.resrgan_scale == 1: + feat = pixel_unshuffle(x, scale=4) + elif self.resrgan_scale == 2: + feat = pixel_unshuffle(x, scale=2) + else: + feat = x + + return self.model(feat) + + +class RRDB(nn.Module): + """ + Residual in Residual Dense Block + (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) + """ + + def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', + norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', + spectral_norm=False, gaussian_noise=False, plus=False): + super(RRDB, self).__init__() + # This is for backwards compatibility with existing models + if nr == 3: + self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, + gaussian_noise=gaussian_noise, plus=plus) + self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, + gaussian_noise=gaussian_noise, plus=plus) + self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, + gaussian_noise=gaussian_noise, plus=plus) + else: + RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, + gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)] + self.RDBs = nn.Sequential(*RDB_list) + + def forward(self, x): + if hasattr(self, 'RDB1'): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + else: + out = self.RDBs(x) + return out * 0.2 + x + + +class ResidualDenseBlock_5C(nn.Module): + """ + Residual Dense Block + The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) + Modified options that can be used: + - "Partial Convolution based Padding" arXiv:1811.11718 + - "Spectral normalization" arXiv:1802.05957 + - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. + {Rakotonirina} and A. {Rasoanaivo} + """ + + def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', + norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', + spectral_norm=False, gaussian_noise=False, plus=False): + super(ResidualDenseBlock_5C, self).__init__() + + self.noise = GaussianNoise() if gaussian_noise else None + self.conv1x1 = conv1x1(nf, gc) if plus else None + + self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + if mode == 'CNA': + last_act = None + else: + last_act = act_type + self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(torch.cat((x, x1), 1)) + if self.conv1x1: + x2 = x2 + self.conv1x1(x) + x3 = self.conv3(torch.cat((x, x1, x2), 1)) + x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) + if self.conv1x1: + x4 = x4 + x2 + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + if self.noise: + return self.noise(x5.mul(0.2) + x) + else: + return x5 * 0.2 + x + + +#################### +# ESRGANplus +#################### + +class GaussianNoise(nn.Module): + def __init__(self, sigma=0.1, is_relative_detach=False): + super().__init__() + self.sigma = sigma + self.is_relative_detach = is_relative_detach + self.noise = torch.tensor(0, dtype=torch.float) + + def forward(self, x): + if self.training and self.sigma != 0: + self.noise = self.noise.to(x.device) + scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x + sampled_noise = self.noise.repeat(*x.size()).normal_() * scale + x = x + sampled_noise + return x + +def conv1x1(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +#################### +# SRVGGNetCompact +#################### + +class SRVGGNetCompact(nn.Module): + """A compact VGG-style network structure for super-resolution. + This class is copied from https://github.com/xinntao/Real-ESRGAN + """ + + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): + super(SRVGGNetCompact, self).__init__() + self.num_in_ch = num_in_ch + self.num_out_ch = num_out_ch + self.num_feat = num_feat + self.num_conv = num_conv + self.upscale = upscale + self.act_type = act_type + + self.body = nn.ModuleList() + # the first conv + self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) + # the first activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the body structure + for _ in range(num_conv): + self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) + # activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the last conv + self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) + # upsample + self.upsampler = nn.PixelShuffle(upscale) + + def forward(self, x): + out = x + for i in range(0, len(self.body)): + out = self.body[i](out) + + out = self.upsampler(out) + # add the nearest upsampled image, so that the network learns the residual + base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') + out += base + return out + + +#################### +# Upsampler +#################### + +class Upsample(nn.Module): + r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. + The input data is assumed to be of the form + `minibatch x channels x [optional depth] x [optional height] x width`. + """ + + def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None): + super(Upsample, self).__init__() + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.size = size + self.align_corners = align_corners + + def forward(self, x): + return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + + def extra_repr(self): + if self.scale_factor is not None: + info = 'scale_factor=' + str(self.scale_factor) + else: + info = 'size=' + str(self.size) + info += ', mode=' + self.mode + return info + + +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, + pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'): + """ + Pixel shuffle layer + (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional + Neural Network, CVPR17) + """ + conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, + pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + + n = norm(norm_type, out_nc) if norm_type else None + a = act(act_type) if act_type else None + return sequential(conv, pixel_shuffle, n, a) + + +def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, + pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'): + """ Upconv layer """ + upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor + upsample = Upsample(scale_factor=upscale_factor, mode=mode) + conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, + pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype) + return sequential(upsample, conv) + + + + + + + + +#################### +# Basic blocks +#################### + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + Args: + basic_block (nn.module): nn.module class for basic block. (block) + num_basic_block (int): number of blocks. (n_layers) + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0): + """ activation helper """ + act_type = act_type.lower() + if act_type == 'relu': + layer = nn.ReLU(inplace) + elif act_type in ('leakyrelu', 'lrelu'): + layer = nn.LeakyReLU(neg_slope, inplace) + elif act_type == 'prelu': + layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) + elif act_type == 'tanh': # [-1, 1] range output + layer = nn.Tanh() + elif act_type == 'sigmoid': # [0, 1] range output + layer = nn.Sigmoid() + else: + raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) + return layer + + +class Identity(nn.Module): + def __init__(self, *kwargs): + super(Identity, self).__init__() + + def forward(self, x, *kwargs): + return x + + +def norm(norm_type, nc): + """ Return a normalization layer """ + norm_type = norm_type.lower() + if norm_type == 'batch': + layer = nn.BatchNorm2d(nc, affine=True) + elif norm_type == 'instance': + layer = nn.InstanceNorm2d(nc, affine=False) + elif norm_type == 'none': + def norm_layer(x): return Identity() + else: + raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type)) + return layer + + +def pad(pad_type, padding): + """ padding layer helper """ + pad_type = pad_type.lower() + if padding == 0: + return None + if pad_type == 'reflect': + layer = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + layer = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + layer = nn.ZeroPad2d(padding) + else: + raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type)) + return layer + + +def get_valid_padding(kernel_size, dilation): + kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) + padding = (kernel_size - 1) // 2 + return padding + + +class ShortcutBlock(nn.Module): + """ Elementwise sum the output of a submodule to its input """ + def __init__(self, submodule): + super(ShortcutBlock, self).__init__() + self.sub = submodule + + def forward(self, x): + output = x + self.sub(x) + return output + + def __repr__(self): + return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|') + + +def sequential(*args): + """ Flatten Sequential. It unwraps nn.Sequential. """ + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError('sequential does not support OrderedDict input.') + return args[0] # No sequential is needed. + modules = [] + for module in args: + if isinstance(module, nn.Sequential): + for submodule in module.children(): + modules.append(submodule) + elif isinstance(module, nn.Module): + modules.append(module) + return nn.Sequential(*modules) + + +def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, + pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', + spectral_norm=False): + """ Conv layer with padding, normalization, activation """ + assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode) + padding = get_valid_padding(kernel_size, dilation) + p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None + padding = padding if pad_type == 'zero' else 0 + + if convtype=='PartialConv2D': + c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, bias=bias, groups=groups) + elif convtype=='DeformConv2D': + c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, bias=bias, groups=groups) + elif convtype=='Conv3D': + c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, bias=bias, groups=groups) + else: + c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, bias=bias, groups=groups) + + if spectral_norm: + c = nn.utils.spectral_norm(c) + + a = act(act_type) if act_type else None + if 'CNA' in mode: + n = norm(norm_type, out_nc) if norm_type else None + return sequential(p, c, n, a) + elif mode == 'NAC': + if norm_type is None and act_type is not None: + a = act(act_type, inplace=False) + n = norm(norm_type, in_nc) if norm_type else None + return sequential(n, a, p, c) -- cgit v1.2.3 From 542a3d3a4a00c1383fbdaf938ceefef87cf834bb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 14:33:22 +0300 Subject: fix btoken hypernetworks in XY plot --- modules/hypernetwork.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 19f1c227..498bc9d8 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -49,15 +49,18 @@ def list_hypernetworks(path): def load_hypernetwork(filename): - print(f"Loading hypernetwork {filename}") path = shared.hypernetworks.get(filename, None) - if (path is not None): + if path is not None: + print(f"Loading hypernetwork {filename}") try: shared.loaded_hypernetwork = Hypernetwork(path) except Exception: print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") + shared.loaded_hypernetwork = None -- cgit v1.2.3 From d6d10a37bfd21568e74efb46137f906da96d5fdb Mon Sep 17 00:00:00 2001 From: William Moorehouse Date: Sun, 9 Oct 2022 04:58:40 -0400 Subject: Added extended model details to infotext --- modules/processing.py | 3 +++ modules/sd_models.py | 3 ++- modules/shared.py | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 2c991317..d1bcee4a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -284,6 +284,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), + "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_name else shared.sd_model.sd_model_name), + "Model VAE": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_vae_name else shared.sd_model.sd_model_vae_name), + "Model hypernetwork": (None if not opts.add_extended_model_details_to_info or not opts.sd_hypernetwork else opts.sd_hypernetwork), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), diff --git a/modules/sd_models.py b/modules/sd_models.py index d0c74dd8..3fa42329 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,7 +4,7 @@ import sys from collections import namedtuple import torch from omegaconf import OmegaConf - +from pathlib import Path from ldm.util import instantiate_from_config @@ -158,6 +158,7 @@ def load_model_weights(model, checkpoint_info): vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} model.first_stage_model.load_state_dict(vae_dict) + model.sd_model_vae_name = Path(vae_file).stem model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file diff --git a/modules/shared.py b/modules/shared.py index dffa0094..ca63f7d8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -242,6 +242,7 @@ options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), + "add_extended_model_details_to_info": OptionInfo(False, "Add extended model details to generation information (model name, VAE, hypernetwork)"), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), -- cgit v1.2.3 From 006791c13d70e582eee766b7d0499e9821a86bf9 Mon Sep 17 00:00:00 2001 From: William Moorehouse Date: Sun, 9 Oct 2022 05:09:18 -0400 Subject: Fix grabbing the model name for infotext --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index d1bcee4a..c035c990 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -284,7 +284,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), - "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_name else shared.sd_model.sd_model_name), + "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name), "Model VAE": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_vae_name else shared.sd_model.sd_model_vae_name), "Model hypernetwork": (None if not opts.add_extended_model_details_to_info or not opts.sd_hypernetwork else opts.sd_hypernetwork), "Batch size": (None if p.batch_size < 2 else p.batch_size), -- cgit v1.2.3 From 594cbfd8fbe4078b43ceccf01509eeef3d6790c6 Mon Sep 17 00:00:00 2001 From: William Moorehouse Date: Sun, 9 Oct 2022 07:27:11 -0400 Subject: Sanitize infotext output (for now) --- modules/processing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index c035c990..049f3769 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -284,9 +284,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), - "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name), - "Model VAE": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_vae_name else shared.sd_model.sd_model_vae_name), - "Model hypernetwork": (None if not opts.add_extended_model_details_to_info or not opts.sd_hypernetwork else opts.sd_hypernetwork), + "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), + "Model VAE": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_vae_name else shared.sd_model.sd_model_vae_name.replace(',', '').replace(':', '')), + "Model hypernetwork": (None if not opts.add_extended_model_details_to_info or not opts.sd_hypernetwork else opts.sd_hypernetwork.replace(',', '').replace(':', '')), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), -- cgit v1.2.3 From e6e8cabe0c9c335e0d72345602c069b198558b53 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 14:57:48 +0300 Subject: change up #2056 to make it work how i want it to plus make xy plot write correct values to images --- modules/processing.py | 5 ++--- modules/sd_models.py | 2 -- modules/shared.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 049f3769..04aed989 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -284,9 +284,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), - "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Model VAE": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_vae_name else shared.sd_model.sd_model_vae_name.replace(',', '').replace(':', '')), - "Model hypernetwork": (None if not opts.add_extended_model_details_to_info or not opts.sd_hypernetwork else opts.sd_hypernetwork.replace(',', '').replace(':', '')), + "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), + "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), diff --git a/modules/sd_models.py b/modules/sd_models.py index 3fa42329..e63d3c29 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,7 +4,6 @@ import sys from collections import namedtuple import torch from omegaconf import OmegaConf -from pathlib import Path from ldm.util import instantiate_from_config @@ -158,7 +157,6 @@ def load_model_weights(model, checkpoint_info): vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} model.first_stage_model.load_state_dict(vae_dict) - model.sd_model_vae_name = Path(vae_file).stem model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file diff --git a/modules/shared.py b/modules/shared.py index ca63f7d8..6ecc2503 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -242,7 +242,7 @@ options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), - "add_extended_model_details_to_info": OptionInfo(False, "Add extended model details to generation information (model name, VAE, hypernetwork)"), + "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), -- cgit v1.2.3 From 9d1138e2940c4ddcd2685bcba12c7d407e9e0ec5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 15:08:10 +0300 Subject: fix typo in filename for ESRGAN arch --- modules/esrgam_model_arch.py | 80 -------------------------------------------- modules/esrgan_model.py | 2 +- modules/esrgan_model_arch.py | 80 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 81 deletions(-) delete mode 100644 modules/esrgam_model_arch.py create mode 100644 modules/esrgan_model_arch.py (limited to 'modules') diff --git a/modules/esrgam_model_arch.py b/modules/esrgam_model_arch.py deleted file mode 100644 index e413d36e..00000000 --- a/modules/esrgam_model_arch.py +++ /dev/null @@ -1,80 +0,0 @@ -# this file is taken from https://github.com/xinntao/ESRGAN - -import functools -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def make_layer(block, n_layers): - layers = [] - for _ in range(n_layers): - layers.append(block()) - return nn.Sequential(*layers) - - -class ResidualDenseBlock_5C(nn.Module): - def __init__(self, nf=64, gc=32, bias=True): - super(ResidualDenseBlock_5C, self).__init__() - # gc: growth channel, i.e. intermediate channels - self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) - self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) - self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) - self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) - self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - # initialization - # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) - - def forward(self, x): - x1 = self.lrelu(self.conv1(x)) - x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) - x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) - x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5 * 0.2 + x - - -class RRDB(nn.Module): - '''Residual in Residual Dense Block''' - - def __init__(self, nf, gc=32): - super(RRDB, self).__init__() - self.RDB1 = ResidualDenseBlock_5C(nf, gc) - self.RDB2 = ResidualDenseBlock_5C(nf, gc) - self.RDB3 = ResidualDenseBlock_5C(nf, gc) - - def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - return out * 0.2 + x - - -class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32): - super(RRDBNet, self).__init__() - RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) - - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.RRDB_trunk = make_layer(RRDB_block_f, nb) - self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - #### upsampling - self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - def forward(self, x): - fea = self.conv_first(x) - trunk = self.trunk_conv(self.RRDB_trunk(fea)) - fea = fea + trunk - - fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) - fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) - out = self.conv_last(self.lrelu(self.HRconv(fea))) - - return out diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 3970e6e4..46ad0da3 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -5,7 +5,7 @@ import torch from PIL import Image from basicsr.utils.download_util import load_file_from_url -import modules.esrgam_model_arch as arch +import modules.esrgan_model_arch as arch from modules import shared, modelloader, images, devices from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py new file mode 100644 index 00000000..e413d36e --- /dev/null +++ b/modules/esrgan_model_arch.py @@ -0,0 +1,80 @@ +# this file is taken from https://github.com/xinntao/ESRGAN + +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class RRDBNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out -- cgit v1.2.3 From 875ddfeecfaffad9eee24813301637cba310337d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 17:58:43 +0300 Subject: added guard for torch.load to prevent loading pickles with unknown content --- modules/paths.py | 1 + modules/safe.py | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 1 + 3 files changed, 91 insertions(+) create mode 100644 modules/safe.py (limited to 'modules') diff --git a/modules/paths.py b/modules/paths.py index 0519caa0..1e7a2fbc 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -1,6 +1,7 @@ import argparse import os import sys +import modules.safe script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) models_path = os.path.join(script_path, "models") diff --git a/modules/safe.py b/modules/safe.py new file mode 100644 index 00000000..2d2c1371 --- /dev/null +++ b/modules/safe.py @@ -0,0 +1,89 @@ +# this code is adapted from the script contributed by anon from /h/ + +import io +import pickle +import collections +import sys +import traceback + +import torch +import numpy +import _codecs +import zipfile + + +def encode(*args): + out = _codecs.encode(*args) + return out + + +class RestrictedUnpickler(pickle.Unpickler): + def persistent_load(self, saved_id): + assert saved_id[0] == 'storage' + return torch.storage._TypedStorage() + + def find_class(self, module, name): + if module == 'collections' and name == 'OrderedDict': + return getattr(collections, name) + if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: + return getattr(torch._utils, name) + if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage']: + return getattr(torch, name) + if module == 'torch.nn.modules.container' and name in ['ParameterDict']: + return getattr(torch.nn.modules.container, name) + if module == 'numpy.core.multiarray' and name == 'scalar': + return numpy.core.multiarray.scalar + if module == 'numpy' and name == 'dtype': + return numpy.dtype + if module == '_codecs' and name == 'encode': + return encode + if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': + import pytorch_lightning.callbacks + return pytorch_lightning.callbacks.model_checkpoint + if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': + import pytorch_lightning.callbacks.model_checkpoint + return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint + if module == "__builtin__" and name == 'set': + return set + + # Forbid everything else. + raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") + + +def check_pt(filename): + try: + + # new pytorch format is a zip file + with zipfile.ZipFile(filename) as z: + with z.open('archive/data.pkl') as file: + unpickler = RestrictedUnpickler(file) + unpickler.load() + + except zipfile.BadZipfile: + + # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle + with open(filename, "rb") as file: + unpickler = RestrictedUnpickler(file) + for i in range(5): + unpickler.load() + + +def load(filename, *args, **kwargs): + from modules import shared + + try: + if not shared.cmd_opts.disable_safe_unpickle: + check_pt(filename) + + except Exception: + print(f"Error verifying pickled file from {filename}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) + print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr) + return None + + return unsafe_torch_load(filename, *args, **kwargs) + + +unsafe_torch_load = torch.load +torch.load = load diff --git a/modules/shared.py b/modules/shared.py index 6ecc2503..3d7f08e1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -65,6 +65,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) +parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) cmd_opts = parser.parse_args() -- cgit v1.2.3 From 9ecea0a8d6bdc434755e11128487fd62f1ff130f Mon Sep 17 00:00:00 2001 From: Artem Zagidulin Date: Sun, 9 Oct 2022 16:14:56 +0300 Subject: fix missing png info when Extras Batch Process --- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 39dd3806..41e8612c 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -29,7 +29,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v if extras_mode == 1: #convert file to pillow image for img in image_folder: - image = Image.fromarray(np.array(Image.open(img))) + image = Image.open(img) imageArr.append(image) imageNameArr.append(os.path.splitext(img.orig_name)[0]) else: -- cgit v1.2.3 From 6c383d2e82045fc4475d665f83bdeeac8fd844d9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 22:24:07 +0300 Subject: show model selection setting on top of page --- modules/shared.py | 5 +++-- modules/ui.py | 54 +++++++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 48 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 3d7f08e1..270fa402 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -131,13 +131,14 @@ def realesrgan_models_names(): class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None): + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False): self.default = default self.label = label self.component = component self.component_args = component_args self.onchange = onchange self.section = None + self.show_on_main_page = show_on_main_page def options_section(section_identifier, options_dict): @@ -214,7 +215,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True), "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), diff --git a/modules/ui.py b/modules/ui.py index dad509f3..2231a8ed 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1175,10 +1175,13 @@ Requested path was: {f} changed = 0 for key, value, comp in zip(opts.data_labels.keys(), args, components): - if not opts.same_type(value, opts.data_labels[key].default): - return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default): + return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson() for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + comp_args = opts.data_labels[key].component_args if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: continue @@ -1196,6 +1199,21 @@ Requested path was: {f} return f'{changed} settings changed.', opts.dumpjson() + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + oldval = opts.data.get(key, None) + opts.data[key] = value + + if oldval != value: + if opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + + return gr.update(value=value), opts.dumpjson() + with gr.Blocks(analytics_enabled=False) as settings_interface: settings_submit = gr.Button(value="Apply settings", variant='primary') result = gr.HTML() @@ -1203,6 +1221,8 @@ Requested path was: {f} settings_cols = 3 items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) + quicksettings_list = [] + cols_displayed = 0 items_displayed = 0 previous_section = None @@ -1225,10 +1245,14 @@ Requested path was: {f} gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='

{}

'.format(item.section[1])) - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - items_displayed += 1 + if item.show_on_main_page: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + items_displayed += 1 request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") request_notifications.click( @@ -1242,7 +1266,6 @@ Requested path was: {f} reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - def reload_scripts(): modules.scripts.reload_script_body_only() @@ -1289,7 +1312,11 @@ Requested path was: {f} css += css_hide_progressbar with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - + with gr.Row(elem_id="quicksettings"): + for i, k, item in quicksettings_list: + component = create_setting_component(k) + component_dict[k] = component + settings_interface.gradio_ref = demo with gr.Tabs() as tabs: @@ -1306,7 +1333,16 @@ Requested path was: {f} inputs=components, outputs=[result, text_settings], ) - + + for i, k, item in quicksettings_list: + component = component_dict[k] + + component.change( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + ) + def modelmerger(*args): try: results = modules.extras.run_modelmerger(*args) -- cgit v1.2.3 From e59c66c0088422b27f64b401ef42c242f836725a Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 16:32:05 -0400 Subject: Optimized code for Ignoring last CLIP layers --- modules/sd_hijack.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f12a9696..4a2d2153 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -282,14 +282,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - tmp = -opts.CLIP_ignore_last_layers - if (opts.CLIP_ignore_last_layers == 0): - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) - z = outputs.last_hidden_state - else: - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) - z = outputs.hidden_states[tmp] - z = self.wrapped.transformer.text_model.final_layer_norm(z) + tmp = -opts.CLIP_stop_at_last_layers + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) + z = outputs.hidden_states[tmp] + z = self.wrapped.transformer.text_model.final_layer_norm(z) # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] -- cgit v1.2.3 From a14f7bf113a2af9e06a1c4d06c2efa244f9c5730 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 16:33:06 -0400 Subject: Corrected CLIP Layer Ignore description and updated its range to the max possible --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 270fa402..1995a99a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -225,7 +225,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), - 'CLIP_ignore_last_layers': OptionInfo(0, "Ignore last layers of CLIP model", gr.Slider, {"minimum": 0, "maximum": 5, "step": 1}), + 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.3 From ec2bd9be75865c9f3a8c898163ab381688c03b6e Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 17:28:42 -0400 Subject: Fix issues with CLIP ignore option name change --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 04aed989..92a105a2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -129,7 +129,7 @@ class Processed: self.index_of_first_image = index_of_first_image self.styles = p.styles self.job_timestamp = state.job_timestamp - self.clip_skip = opts.CLIP_ignore_last_layers + self.clip_skip = opts.CLIP_stop_at_last_layers self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -274,7 +274,7 @@ def fix_seed(p): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): index = position_in_batch + iteration * p.batch_size - clip_skip = getattr(p, 'clip_skip', opts.CLIP_ignore_last_layers) + clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) generation_params = { "Steps": p.steps, -- cgit v1.2.3 From ad3ae441081155dcd4fde805279e5082ca264695 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sun, 9 Oct 2022 04:32:40 -0400 Subject: Updated code for legibility --- modules/sd_hijack.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 4a2d2153..7793d25b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -284,8 +284,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): tmp = -opts.CLIP_stop_at_last_layers outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) - z = outputs.hidden_states[tmp] - z = self.wrapped.transformer.text_model.final_layer_norm(z) + if tmp < -1: + z = outputs.hidden_states[tmp] + z = self.wrapped.transformer.text_model.final_layer_norm(z) + else: + z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] -- cgit v1.2.3 From 1824e9ee3ab4f94aee8908a62ea2569a01aeb3d7 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sun, 9 Oct 2022 14:15:43 -0400 Subject: Removed unnecessary tmp variable --- modules/sd_hijack.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 7793d25b..437acce4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -282,10 +282,9 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - tmp = -opts.CLIP_stop_at_last_layers - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) - if tmp < -1: - z = outputs.hidden_states[tmp] + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers) + if opts.CLIP_stop_at_last_layers > 1: + z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] z = self.wrapped.transformer.text_model.final_layer_norm(z) else: z = outputs.last_hidden_state -- cgit v1.2.3 From 8d340cfb884e1dbff5b6f477f4ecf7d104279115 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 22:30:59 +0300 Subject: do not add clip skip to parameters if it's 1 or 0 --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 92a105a2..94d2dd62 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -293,7 +293,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), - "Clip skip": None if clip_skip==0 else clip_skip, + "Clip skip": None if clip_skip <= 1 else clip_skip, } generation_params.update(p.extra_generation_params) -- cgit v1.2.3 From fa0c5eb81b72bc1e562d0b9bbd92f30945d78b4e Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 20:41:22 +0100 Subject: Add pretty image captioning functions --- modules/images.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 29c5ee24..10963dc7 100644 --- a/modules/images.py +++ b/modules/images.py @@ -428,3 +428,34 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i file.write(info + "\n") return fullfn + +def addCaptionLines(lines,image,initialx,textfont): + draw = ImageDraw.Draw(image) + hstart =initialx + for fill,line in lines: + fontSize = 32 + font = ImageFont.truetype(textfont, fontSize) + _,_,w, h = draw.textbbox((0,0),line,font=font) + fontSize = min( int(fontSize * ((image.size[0]-35)/w) ), 28) + font = ImageFont.truetype(textfont, fontSize) + _,_,w,h = draw.textbbox((0,0),line,font=font) + draw.text(((image.size[0]-w)/2,hstart), line, font=font, fill=fill) + hstart += h + return hstart + +def captionImge(image,prelines,postlines,background=(51, 51, 51),font=None): + if font is None: + try: + font = ImageFont.truetype(opts.font or Roboto, fontsize) + font = opts.font or Roboto + except Exception: + font = Roboto + + sampleImage = image + background = Image.new("RGBA", (sampleImage.size[0],sampleImage.size[1]+1024), background) + hoffset = addCaptionLines(prelines,background,5,font)+16 + background.paste(sampleImage,(0,hoffset)) + hoffset = hoffset+sampleImage.size[1]+8 + hoffset = addCaptionLines(postlines,background,hoffset,font) + background = background.crop((0,0,sampleImage.size[0],hoffset+8)) + return background -- cgit v1.2.3 From a65476718f08a35f527b973ef731e6f488bace5e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 23:38:49 +0300 Subject: add DoubleStorage to list of allowed classes for pickle --- modules/safe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/safe.py b/modules/safe.py index 2d2c1371..4d06f2a5 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -27,7 +27,7 @@ class RestrictedUnpickler(pickle.Unpickler): return getattr(collections, name) if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: return getattr(torch._utils, name) - if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage']: + if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']: return getattr(torch, name) if module == 'torch.nn.modules.container' and name in ['ParameterDict']: return getattr(torch.nn.modules.container, name) -- cgit v1.2.3 From 03694e1f9915e34cf7d9a31073f1a1a9def2909f Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 21:58:14 +0100 Subject: add embedding load and save from b64 json --- modules/textual_inversion/textual_inversion.py | 30 ++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index f6316020..1b7f8906 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,9 +7,11 @@ import tqdm import html import datetime -from PIL import Image, PngImagePlugin +from PIL import Image,PngImagePlugin +from ..images import captionImge +import numpy as np import base64 -from io import BytesIO +import json from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset @@ -87,9 +89,9 @@ class EmbeddingDatabase: if filename.upper().endswith('.PNG'): embed_image = Image.open(path) - if 'sd-embedding' in embed_image.text: - embeddingData = base64.b64decode(embed_image.text['sd-embedding']) - data = torch.load(BytesIO(embeddingData), map_location="cpu") + if 'sd-ti-embedding' in embed_image.text: + data = embeddingFromB64(embed_image.text['sd-ti-embedding']) + name = data.get('name',name) else: data = torch.load(path, map_location="cpu") @@ -258,13 +260,23 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, if save_image_with_stored_embedding: info = PngImagePlugin.PngInfo() - info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read())) - image.save(last_saved_image, "PNG", pnginfo=info) + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embeddingToB64(data)) + + pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] + + caption_checkpoint_hash = data.get('sd_checkpoint','UNK') + caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK' + caption_stepcount = data.get('step',0) + caption_stepcount = caption_stepcount if caption_stepcount else 0 + + post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash, + caption_stepcount))] + captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines) + captioned_image.save(last_saved_image, "PNG", pnginfo=info) else: image.save(last_saved_image) - - last_saved_image += f", prompt: {text}" shared.state.job_no = embedding.step -- cgit v1.2.3 From 969bd8256e5b4f1007d3cc653723d4ad50a92528 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:02:28 +0100 Subject: add alternate checkpoint hash source --- modules/textual_inversion/textual_inversion.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 1b7f8906..d7813084 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -265,8 +265,11 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] - caption_checkpoint_hash = data.get('sd_checkpoint','UNK') - caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK' + caption_checkpoint_hash = data.get('sd_checkpoint') + if caption_checkpoint_hash is None: + caption_checkpoint_hash = data.get('hash') + caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNKNOWN' + caption_stepcount = data.get('step',0) caption_stepcount = caption_stepcount if caption_stepcount else 0 -- cgit v1.2.3 From 5d12ec82d3e13f5ff4c55db2930e4e10aed7015a Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:05:09 +0100 Subject: add encoder and decoder classes --- modules/textual_inversion/textual_inversion.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d7813084..44d4e08b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -16,6 +16,27 @@ import json from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset +class EmbeddingEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, torch.Tensor): + return {'EMBEDDINGTENSOR':obj.cpu().detach().numpy().tolist()} + return json.JSONEncoder.default(self, o) + +class EmbeddingDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + def object_hook(self, d): + if 'EMBEDDINGTENSOR' in d: + return torch.from_numpy(np.array(d['EMBEDDINGTENSOR'])) + return d + +def embeddingToB64(data): + d = json.dumps(data,cls=EmbeddingEncoder) + return base64.b64encode(d.encode()) + +def EmbeddingFromB64(data): + d = base64.b64decode(data) + return json.loads(d,cls=EmbeddingDecoder) class Embedding: def __init__(self, vec, name, step=None): -- cgit v1.2.3 From d0184b8f76ce492da699f1926f34b57cd095242e Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:06:12 +0100 Subject: change json tensor key name --- modules/textual_inversion/textual_inversion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 44d4e08b..ae8d207d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -19,15 +19,15 @@ import modules.textual_inversion.dataset class EmbeddingEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, torch.Tensor): - return {'EMBEDDINGTENSOR':obj.cpu().detach().numpy().tolist()} + return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()} return json.JSONEncoder.default(self, o) class EmbeddingDecoder(json.JSONDecoder): def __init__(self, *args, **kwargs): json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) def object_hook(self, d): - if 'EMBEDDINGTENSOR' in d: - return torch.from_numpy(np.array(d['EMBEDDINGTENSOR'])) + if 'TORCHTENSOR' in d: + return torch.from_numpy(np.array(d['TORCHTENSOR'])) return d def embeddingToB64(data): -- cgit v1.2.3 From 66846105103cfc282434d0dc2102910160b7a633 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:06:42 +0100 Subject: correct case on embeddingFromB64 --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ae8d207d..d2b95fa3 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -34,7 +34,7 @@ def embeddingToB64(data): d = json.dumps(data,cls=EmbeddingEncoder) return base64.b64encode(d.encode()) -def EmbeddingFromB64(data): +def embeddingFromB64(data): d = base64.b64decode(data) return json.loads(d,cls=EmbeddingDecoder) -- cgit v1.2.3 From 96f1e6be59316ec640cab2435fa95b3688194906 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:14:50 +0100 Subject: source checkpoint hash from current checkpoint --- modules/textual_inversion/textual_inversion.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d2b95fa3..b16fa84e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -286,10 +286,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] - caption_checkpoint_hash = data.get('sd_checkpoint') - if caption_checkpoint_hash is None: - caption_checkpoint_hash = data.get('hash') - caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNKNOWN' + checkpoint = sd_models.select_checkpoint() + caption_checkpoint_hash = checkpoint.hash caption_stepcount = data.get('step',0) caption_stepcount = caption_stepcount if caption_stepcount else 0 -- cgit v1.2.3 From 01fd9cf0d28d8b71a113ab1aa62accfe7f0d9c51 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:17:02 +0100 Subject: change source of step count --- modules/textual_inversion/textual_inversion.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b16fa84e..e4f339b8 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -285,15 +285,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, info.add_text("sd-ti-embedding", embeddingToB64(data)) pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] - checkpoint = sd_models.select_checkpoint() - caption_checkpoint_hash = checkpoint.hash - - caption_stepcount = data.get('step',0) - caption_stepcount = caption_stepcount if caption_stepcount else 0 - - post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash, - caption_stepcount))] + post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash, + embedding.step))] captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines) captioned_image.save(last_saved_image, "PNG", pnginfo=info) else: -- cgit v1.2.3 From 0ac3a07eecbd7b98f3a19d01dc46f02dcda3443b Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 00:05:36 +0100 Subject: add caption image with overlay --- modules/images.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 10963dc7..4a4fc977 100644 --- a/modules/images.py +++ b/modules/images.py @@ -459,3 +459,49 @@ def captionImge(image,prelines,postlines,background=(51, 51, 51),font=None): hoffset = addCaptionLines(postlines,background,hoffset,font) background = background.crop((0,0,sampleImage.size[0],hoffset+8)) return background + +def captionImageOverlay(srcimage,title,footerLeft,footerMid,footerRight,textfont=None): + from math import cos + + image = srcimage.copy() + + if textfont is None: + try: + textfont = ImageFont.truetype(opts.font or Roboto, fontsize) + textfont = opts.font or Roboto + except Exception: + textfont = Roboto + + factor = 1.5 + gradient = Image.new('RGBA', (1,image.size[1]), color=(0,0,0,0)) + for y in range(image.size[1]): + mag = 1-cos(y/image.size[1]*factor) + mag = max(mag,1-cos((image.size[1]-y)/image.size[1]*factor*1.1)) + gradient.putpixel((0, y), (0,0,0,int(mag*255))) + image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) + + draw = ImageDraw.Draw(image) + fontSize = 32 + font = ImageFont.truetype(textfont, fontSize) + padding = 10 + + _,_,w, h = draw.textbbox((0,0),title,font=font) + fontSize = min( int(fontSize * (((image.size[0]*0.75)-(padding*4))/w) ), 72) + font = ImageFont.truetype(textfont, fontSize) + _,_,w,h = draw.textbbox((0,0),title,font=font) + draw.text((padding,padding), title, anchor='lt', font=font, fill=(255,255,255,230)) + + _,_,w, h = draw.textbbox((0,0),footerLeft,font=font) + fontSizeleft = min( int(fontSize * (((image.size[0]/3)-(padding))/w) ), 72) + _,_,w, h = draw.textbbox((0,0),footerMid,font=font) + fontSizemid = min( int(fontSize * (((image.size[0]/3)-(padding))/w) ), 72) + _,_,w, h = draw.textbbox((0,0),footerRight,font=font) + fontSizeright = min( int(fontSize * (((image.size[0]/3)-(padding))/w) ), 72) + + font = ImageFont.truetype(textfont, min(fontSizeleft,fontSizemid,fontSizeright)) + + draw.text((padding,image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255,255,255,230)) + draw.text((image.size[0]/2,image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255,255,255,230)) + draw.text((image.size[0]-padding,image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255,255,255,230)) + + return image -- cgit v1.2.3 From d6a599ef9ba18a66ae79b50f2945af5788fdda8f Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 00:07:52 +0100 Subject: change caption method --- modules/textual_inversion/textual_inversion.py | 30 ++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e4f339b8..21596e78 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -8,7 +8,7 @@ import html import datetime from PIL import Image,PngImagePlugin -from ..images import captionImge +from ..images import captionImageOverlay import numpy as np import base64 import json @@ -212,6 +212,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, else: images_dir = None + if create_image_every > 0 and save_image_with_stored_embedding: + images_embeds_dir = os.path.join(log_directory, "image_embeddings") + os.makedirs(images_embeds_dir, exist_ok=True) + else: + images_embeds_dir = None + cond_model = shared.sd_model.cond_stage_model shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." @@ -279,19 +285,25 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, shared.state.current_image = image - if save_image_with_stored_embedding: + if save_image_with_stored_embedding and os.path.exists(last_saved_file): + + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png') + info = PngImagePlugin.PngInfo() data = torch.load(last_saved_file) info.add_text("sd-ti-embedding", embeddingToB64(data)) - pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))] + title = "<{}>".format(data.get('name','???')) checkpoint = sd_models.select_checkpoint() - post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(checkpoint.hash, - embedding.step))] - captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines) - captioned_image.save(last_saved_image, "PNG", pnginfo=info) - else: - image.save(last_saved_image) + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '[{}]'.format(embedding.step) + + captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) + + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + + image.save(last_saved_image) last_saved_image += f", prompt: {text}" -- cgit v1.2.3 From e2c2925eb4d634b186de2c76798162ec56e2f869 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 00:12:53 +0100 Subject: remove braces from steps --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 21596e78..9a18ee5c 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -297,7 +297,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '[{}]'.format(embedding.step) + footer_right = '{}'.format(embedding.step) captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) -- cgit v1.2.3 From 6435691bb11c5a35703720bfd2a875f24c066f86 Mon Sep 17 00:00:00 2001 From: Justin Maier Date: Sun, 9 Oct 2022 19:26:52 -0600 Subject: Add "Scale to" option to Extras --- modules/extras.py | 28 +++++++++++++++++++++++----- modules/ui.py | 38 +++++++++++++++++++++++++------------- 2 files changed, 48 insertions(+), 18 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 41e8612c..83ca7049 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -1,3 +1,4 @@ +import math import os import numpy as np @@ -19,7 +20,7 @@ import gradio as gr cached_images = {} -def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): +def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): devices.torch_gc() imageArr = [] @@ -67,8 +68,23 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" image = res + if resize_mode == 1: + upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) + crop_info = " (crop)" if upscaling_crop else "" + info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" + + def crop_upscaled_center(image, resize_w, resize_h): + left = int(math.ceil((image.width - resize_w) / 2)) + right = image.width - int(math.floor((image.width - resize_w) / 2)) + top = int(math.ceil((image.height - resize_h) / 2)) + bottom = image.height - int(math.floor((image.height - resize_h) / 2)) + + image = image.crop((left, top, right, bottom)) + return image + + if upscaling_resize != 1.0: - def upscale(image, scaler_index, resize): + def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) pixels = tuple(np.array(small).flatten().tolist()) key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels @@ -77,15 +93,17 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v if c is None: upscaler = shared.sd_upscalers[scaler_index] c = upscaler.scaler.upscale(image, resize, upscaler.data_path) + if mode == 1 and crop: + c = crop_upscaled_center(c, resize_w, resize_h) cached_images[key] = c return c info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n" - res = upscale(image, extras_upscaler_1, upscaling_resize) + res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - res2 = upscale(image, extras_upscaler_2, upscaling_resize) + res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n" res = Image.blend(res, res2, extras_upscaler_2_visibility) @@ -190,7 +208,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint if save_as_half: theta_0[key] = theta_0[key].half() - + for key in theta_1.keys(): if 'model' in key and key not in theta_0: theta_0[key] = theta_1[key] diff --git a/modules/ui.py b/modules/ui.py index 2231a8ed..4bb2892b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -101,7 +101,7 @@ def send_gradio_gallery_to_image(x): def save_files(js_data, images, do_make_zip, index): - import csv + import csv filenames = [] fullfns = [] @@ -551,7 +551,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) - + with gr.Row(): download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) @@ -739,7 +739,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) - + with gr.Row(): download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) @@ -903,7 +903,15 @@ def create_ui(wrap_gradio_gpu_call): with gr.TabItem('Batch Process'): image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") - upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2) + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by'): + upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2) + with gr.TabItem('Scale to'): + with gr.Group(): + with gr.Row(): + upscaling_resize_w = gr.Number(label="Width", value=512) + upscaling_resize_h = gr.Number(label="Height", value=512) + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) with gr.Group(): extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") @@ -934,6 +942,7 @@ def create_ui(wrap_gradio_gpu_call): fn=wrap_gradio_gpu_call(modules.extras.run_extras), _js="get_extras_tab_index", inputs=[ + dummy_component, dummy_component, extras_image, image_batch, @@ -941,6 +950,9 @@ def create_ui(wrap_gradio_gpu_call): codeformer_visibility, codeformer_weight, upscaling_resize, + upscaling_resize_w, + upscaling_resize_h, + upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, @@ -951,14 +963,14 @@ def create_ui(wrap_gradio_gpu_call): html_info, ] ) - + extras_send_to_img2img.click( fn=lambda x: image_from_url_text(x), _js="extract_image_from_gallery_img2img", inputs=[result_images], outputs=[init_img], ) - + extras_send_to_inpaint.click( fn=lambda x: image_from_url_text(x), _js="extract_image_from_gallery_img2img", @@ -1286,7 +1298,7 @@ Requested path was: {f} outputs=[], _js='function(){restart_reload()}' ) - + if column is not None: column.__exit__() @@ -1318,12 +1330,12 @@ Requested path was: {f} component_dict[k] = component settings_interface.gradio_ref = demo - + with gr.Tabs() as tabs: for interface, label, ifid in interfaces: with gr.TabItem(label, id=ifid): interface.render() - + if os.path.exists(os.path.join(script_path, "notification.mp3")): audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) @@ -1456,10 +1468,10 @@ Requested path was: {f} if getattr(obj,'custom_script_source',None) is not None: key = 'customscript/' + obj.custom_script_source + '/' + key - + if getattr(obj, 'do_not_save_to_config', False): return - + saved_value = ui_settings.get(key, None) if saved_value is None: ui_settings[key] = getattr(obj, field) @@ -1483,10 +1495,10 @@ Requested path was: {f} if type(x) == gr.Textbox: apply_field(x, 'value') - + if type(x) == gr.Number: apply_field(x, 'value') - + visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") visit(extras_interface, loadsave, "extras") -- cgit v1.2.3 From cc92dc1f8d73dd4d574c4c8ccab78b7fc61e440b Mon Sep 17 00:00:00 2001 From: ssysm Date: Sun, 9 Oct 2022 23:17:29 -0400 Subject: add vae path args --- modules/sd_models.py | 2 +- modules/shared.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index cb3982b1..b6979432 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -147,7 +147,7 @@ def load_model_weights(model, checkpoint_info): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + vae_file = shared.cmd_opts.vae_path or os.path.splitext(checkpoint_file)[0] + ".vae.pt" if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location="cpu") diff --git a/modules/shared.py b/modules/shared.py index 2dc092d6..52ccfa6e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -64,7 +64,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) - +parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) cmd_opts = parser.parse_args() -- cgit v1.2.3 From 1f92336be768d235c18a82acb2195b7135101ae7 Mon Sep 17 00:00:00 2001 From: JC_Array Date: Sun, 9 Oct 2022 23:58:18 -0500 Subject: refactored the deepbooru module to improve speed on running multiple interogations in a row. Added the option to generate deepbooru tags for textual inversion preproccessing. --- modules/deepbooru.py | 84 +++++++++++++++++++++++++-------- modules/textual_inversion/preprocess.py | 22 ++++++++- modules/ui.py | 52 ++++++++++++++------ 3 files changed, 122 insertions(+), 36 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 7e3c0618..cee4a3b4 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -1,21 +1,74 @@ import os.path from concurrent.futures import ProcessPoolExecutor -from multiprocessing import get_context +import multiprocessing -def _load_tf_and_return_tags(pil_image, threshold): +def get_deepbooru_tags(pil_image, threshold=0.5): + """ + This method is for running only one image at a time for simple use. Used to the img2img interrogate. + """ + from modules import shared # prevents circular reference + create_deepbooru_process(threshold) + shared.deepbooru_process_return["value"] = -1 + shared.deepbooru_process_queue.put(pil_image) + while shared.deepbooru_process_return["value"] == -1: + time.sleep(0.2) + release_process() + return ret + + +def deepbooru_process(queue, deepbooru_process_return, threshold): + model, tags = get_deepbooru_tags_model() + while True: # while process is running, keep monitoring queue for new image + pil_image = queue.get() + if pil_image == "QUIT": + break + else: + deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold) + + +def create_deepbooru_process(threshold=0.5): + """ + Creates deepbooru process. A queue is created to send images into the process. This enables multiple images + to be processed in a row without reloading the model or creating a new process. To return the data, a shared + dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned + to the dictionary and the method adding the image to the queue should wait for this value to be updated with + the tags. + """ + from modules import shared # prevents circular reference + shared.deepbooru_process_manager = multiprocessing.Manager() + shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() + shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() + shared.deepbooru_process_return["value"] = -1 + shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold)) + shared.deepbooru_process.start() + + +def release_process(): + """ + Stops the deepbooru process to return used memory + """ + from modules import shared # prevents circular reference + shared.deepbooru_process_queue.put("QUIT") + shared.deepbooru_process.join() + shared.deepbooru_process_queue = None + shared.deepbooru_process = None + shared.deepbooru_process_return = None + shared.deepbooru_process_manager = None + +def get_deepbooru_tags_model(): import deepdanbooru as dd import tensorflow as tf import numpy as np - this_folder = os.path.dirname(__file__) model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru')) if not os.path.exists(os.path.join(model_path, 'project.json')): # there is no point importing these every time import zipfile from basicsr.utils.download_util import load_file_from_url - load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", - model_path) + load_file_from_url( + r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", + model_path) with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref: zip_ref.extractall(model_path) os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")) @@ -24,7 +77,13 @@ def _load_tf_and_return_tags(pil_image, threshold): model = dd.project.load_model_from_project( model_path, compile_model=True ) + return model, tags + +def get_deepbooru_tags_from_model(model, tags, pil_image, threshold=0.5): + import deepdanbooru as dd + import tensorflow as tf + import numpy as np width = model.input_shape[2] height = model.input_shape[1] image = np.array(pil_image) @@ -57,17 +116,4 @@ def _load_tf_and_return_tags(pil_image, threshold): print('\n'.join(sorted(result_tags_print, reverse=True))) - return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') - - -def subprocess_init_no_cuda(): - import os - os.environ["CUDA_VISIBLE_DEVICES"] = "-1" - - -def get_deepbooru_tags(pil_image, threshold=0.5): - context = get_context('spawn') - with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor: - f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, ) - ret = f.result() # will rethrow any exceptions - return ret \ No newline at end of file + return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') \ No newline at end of file diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index f1c002a2..9f63c9a4 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -3,11 +3,14 @@ from PIL import Image, ImageOps import platform import sys import tqdm +import time from modules import shared, images +from modules.shared import opts, cmd_opts +if cmd_opts.deepdanbooru: + import modules.deepbooru as deepbooru - -def preprocess(process_src, process_dst, process_flip, process_split, process_caption): +def preprocess(process_src, process_dst, process_flip, process_split, process_caption, process_caption_deepbooru=False): size = 512 src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) @@ -24,10 +27,21 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca if process_caption: shared.interrogator.load() + if process_caption_deepbooru: + deepbooru.create_deepbooru_process() + def save_pic_with_caption(image, index): if process_caption: caption = "-" + shared.interrogator.generate_caption(image) caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png") + elif process_caption_deepbooru: + shared.deepbooru_process_return["value"] = -1 + shared.deepbooru_process_queue.put(image) + while shared.deepbooru_process_return["value"] == -1: + time.sleep(0.2) + caption = "-" + shared.deepbooru_process_return["value"] + caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png") + shared.deepbooru_process_return["value"] = -1 else: caption = filename caption = os.path.splitext(caption)[0] @@ -79,6 +93,10 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca if process_caption: shared.interrogator.send_blip_to_ram() + if process_caption_deepbooru: + deepbooru.release_process() + + def sanitize_caption(base_path, original_caption, suffix): operating_system = platform.system().lower() if (operating_system == "windows"): diff --git a/modules/ui.py b/modules/ui.py index 2231a8ed..179e3a83 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1034,6 +1034,9 @@ def create_ui(wrap_gradio_gpu_call): process_flip = gr.Checkbox(label='Create flipped copies') process_split = gr.Checkbox(label='Split oversized images into two') process_caption = gr.Checkbox(label='Use BLIP caption as filename') + if cmd_opts.deepdanbooru: + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename') + with gr.Row(): with gr.Column(scale=3): @@ -1086,21 +1089,40 @@ def create_ui(wrap_gradio_gpu_call): ] ) - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_flip, - process_split, - process_caption, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) + if cmd_opts.deepdanbooru: + # if process_caption_deepbooru is None, it will cause an error, as a result only include it if it is enabled + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + process_src, + process_dst, + process_flip, + process_split, + process_caption, + process_caption_deepbooru, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + else: + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + process_src, + process_dst, + process_flip, + process_split, + process_caption, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) train_embedding.click( fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), -- cgit v1.2.3 From 8acc901ba3a252dc6ab4fabcb41644cf64d1774c Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 10 Oct 2022 00:38:55 -0400 Subject: Newer versions of PyTorch use TypedStorage instead Pytorch 1.13 and later will rename _TypedStorage to TypedStorage, so check for TypedStorage and use _TypedStorage if it is not available. Currently this is needed so that nightly builds of PyTorch work correctly. --- modules/safe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/safe.py b/modules/safe.py index 4d06f2a5..05917463 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -12,6 +12,10 @@ import _codecs import zipfile +# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage +TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage + + def encode(*args): out = _codecs.encode(*args) return out @@ -20,7 +24,7 @@ def encode(*args): class RestrictedUnpickler(pickle.Unpickler): def persistent_load(self, saved_id): assert saved_id[0] == 'storage' - return torch.storage._TypedStorage() + return TypedStorage() def find_class(self, module, name): if module == 'collections' and name == 'OrderedDict': -- cgit v1.2.3 From 8a7c07a2140c98bceca858087525d77fd0352fda Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 10 Oct 2022 15:39:39 +0800 Subject: show image history --- modules/images_history.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++ modules/ui.py | 11 ++++-- 2 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 modules/images_history.py (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py new file mode 100644 index 00000000..23d83557 --- /dev/null +++ b/modules/images_history.py @@ -0,0 +1,90 @@ +import os +def get_recent_images(is_img2img, dir_name, page_index, step): + page_index = int(page_index) + f_list = os.listdir(dir_name) + file_list = [] + for file in f_list: + if file[-4:] == ".txt": + continue + file_list.append(file) + file_list = sorted(file_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) + num = 24 + max_page_index = len(file_list) // num + 1 + page_index = max_page_index if page_index == -1 else page_index + step + page_index = 1 if page_index < 1 else page_index + page_index = max_page_index if page_index > max_page_index else page_index + idx_frm = (page_index - 1) * num + file_list = file_list[idx_frm:idx_frm + num] + print(f"Loading history page {page_index}") + return [os.path.join(dir_name, file) for file in file_list], page_index, file_list +def first_page_click(is_img2img, dir_name): + return get_recent_images(is_img2img, dir_name, 1, 0) +def end_page_click(is_img2img, dir_name): + return get_recent_images(is_img2img, dir_name, -1, 0) +def prev_page_click(is_img2img, dir_name, page_index): + return get_recent_images(is_img2img, dir_name, page_index, -1) +def next_page_click(is_img2img, dir_name, page_index): + return get_recent_images(is_img2img, dir_name, page_index, 1) +def page_index_change(is_img2img, dir_name, page_index): + return get_recent_images(is_img2img, dir_name, page_index, 0) +def show_image_info(num, filenames): + return filenames[int(num)] +def delete_image(is_img2img, dir_name, name, page_index, filenames): + path = os.path.join(dir_name, name) + if os.path.exists(path): + print(f"Delete file {path}") + os.remove(path) + i = 0 + for f in filenames: + if f == name: + break + i += 1 + images, page_index, file_list = get_recent_images(is_img2img, dir_name, page_index, 0) + current_file = file_list[i] if i < len(file_list) else None + return images, page_index, file_list, current_file + + +def show_images_history(gr, opts, is_img2img): + def id_name(is_img2img, name): + return ("img2img" if is_img2img else "txt2img") + "_" + name + with gr.Row(): + if is_img2img: + dir_name = opts.outdir_img2img_samples + else: + dir_name = opts.outdir_txt2img_samples + first_page = gr.Button('First Page', elem_id=id_name(is_img2img,"images_history_first_page")) + prev_page = gr.Button('Prev Page') + page_index = gr.Number(value=1) + next_page = gr.Button('Next Page') + end_page = gr.Button('End Page') + with gr.Row(): + delete = gr.Button('Delete') + Send = gr.Button('Send') + with gr.Row(): + with gr.Column(elem_id=id_name(is_img2img,"images_history")): + history_gallery = gr.Gallery(label="Images history").style(grid=6) + img_file_name = gr.Textbox() + img_file_info = gr.Textbox(dir_name) + img_path = gr.Textbox(dir_name, visible=False) + set_index = gr.Button('set_index', elem_id=id_name(is_img2img,"images_history_set_index")) + is_img2img_flag = gr.Checkbox(is_img2img, visible=False) + filenames = gr.State() + first_page.click(first_page_click, inputs=[is_img2img_flag, img_path], outputs=[history_gallery, page_index, filenames]) + next_page.click(next_page_click, inputs=[is_img2img_flag, img_path, page_index], outputs=[history_gallery, page_index, filenames]) + prev_page.click(prev_page_click, inputs=[is_img2img_flag, img_path, page_index], outputs=[history_gallery, page_index, filenames]) + end_page.click(end_page_click, inputs=[is_img2img_flag, img_path], outputs=[history_gallery, page_index, filenames]) + page_index.submit(page_index_change, inputs=[is_img2img_flag, img_path, page_index], outputs=[history_gallery, page_index, filenames]) + set_index.click(show_image_info, _js="images_history_get_current_img",inputs=[is_img2img_flag, filenames], outputs=img_file_name) + delete.click(delete_image, inputs=[is_img2img_flag, img_path, img_file_name, page_index, filenames], outputs=[history_gallery, page_index, filenames,img_file_name]) + #page_index.change(page_index_change, inputs=[is_img2img_flag, img_path, page_index], outputs=[history_gallery, page_index]) + +def create_history_tabs(gr, opts): + with gr.Blocks(analytics_enabled=False) as images_history: + with gr.Tabs() as tabs: + with gr.Tab("txt2img history", id="images_history_txt2img"): + with gr.Blocks(analytics_enabled=False) as images_history_txt2img: + show_images_history(gr, opts, is_img2img=False) + with gr.Tab("img2img history", id="images_history_img2img"): + with gr.Blocks(analytics_enabled=False) as images_history_img2img: + show_images_history(gr, opts, is_img2img=True) + return images_history diff --git a/modules/ui.py b/modules/ui.py index 4f18126f..8762fcf5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -37,6 +37,7 @@ import modules.generation_parameters_copypaste from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui +import modules.images_history as img_his # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI mimetypes.init() @@ -499,7 +500,6 @@ def create_ui(wrap_gradio_gpu_call): custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) with gr.Column(variant='panel'): - with gr.Group(): txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4) @@ -516,6 +516,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): html_info = gr.HTML() generation_info = gr.Textbox(visible=False) + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -607,6 +608,7 @@ def create_ui(wrap_gradio_gpu_call): ] modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) + with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) @@ -696,6 +698,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): html_info = gr.HTML() generation_info = gr.Textbox(visible=False) + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -1126,8 +1129,10 @@ def create_ui(wrap_gradio_gpu_call): opts.save(shared.config_filename) - return f'{changed} settings changed.', opts.dumpjson() + return f'{changed} settings changed.', opts.dumpjson() + + images_history = img_his.create_history_tabs(gr, opts) with gr.Blocks(analytics_enabled=False) as settings_interface: settings_submit = gr.Button(value="Apply settings", variant='primary') result = gr.HTML() @@ -1206,7 +1211,9 @@ def create_ui(wrap_gradio_gpu_call): (pnginfo_interface, "PNG Info", "pnginfo"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (textual_inversion_interface, "Textual inversion", "ti"), + (images_history, "History", "images_history"), (settings_interface, "Settings", "settings"), + ] with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file: -- cgit v1.2.3 From 3110f895b2718a3a25aae419fdf5c87c177ec9f4 Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Mon, 10 Oct 2022 17:07:46 +0900 Subject: Textual Inversion: Added custom training image size and number of repeats per input image in a single epoch --- modules/textual_inversion/dataset.py | 6 +++--- modules/textual_inversion/preprocess.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 15 ++++++++++++--- modules/ui.py | 8 +++++++- 4 files changed, 24 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 7c44ea5b..acc4ce59 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -15,13 +15,13 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): - def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None): + def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): self.placeholder_token = placeholder_token self.size = size - self.width = width - self.height = height + self.width = size + self.height = size self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.dataset = [] diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index f1c002a2..b3de6fd7 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -7,8 +7,8 @@ import tqdm from modules import shared, images -def preprocess(process_src, process_dst, process_flip, process_split, process_caption): - size = 512 +def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption): + size = process_size src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index cd9f3498..e34dc2e8 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -6,6 +6,7 @@ import torch import tqdm import html import datetime +import math from modules import shared, devices, sd_hijack, processing, sd_models @@ -156,7 +157,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -182,7 +183,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -200,6 +201,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, if ititial_step > steps: return embedding, filename + tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]) + epoch_len = (tr_img_len * num_repeats) + tr_img_len + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) for i, (x, text) in pbar: embedding.step = i + ititial_step @@ -223,7 +227,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, loss.backward() optimizer.step() - pbar.set_description(f"loss: {losses.mean():.7f}") + epoch_num = math.floor(embedding.step / epoch_len) + epoch_step = embedding.step - (epoch_num * epoch_len) + + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}") if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') @@ -236,6 +243,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, sd_model=shared.sd_model, prompt=text, steps=20, + height=training_size, + width=training_size, do_not_save_grid=True, do_not_save_samples=True, ) diff --git a/modules/ui.py b/modules/ui.py index 2231a8ed..f821fd8d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1029,6 +1029,7 @@ def create_ui(wrap_gradio_gpu_call): process_src = gr.Textbox(label='Source directory') process_dst = gr.Textbox(label='Destination directory') + process_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') @@ -1043,13 +1044,15 @@ def create_ui(wrap_gradio_gpu_call): run_preprocess = gr.Button(value="Preprocess", variant='primary') with gr.Group(): - gr.HTML(value="

Train an embedding; must specify a directory with a set of 512x512 images

") + gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) learn_rate = gr.Number(label='Learning rate', value=5.0e-03) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) + training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) + num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) @@ -1092,6 +1095,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ process_src, process_dst, + process_size, process_flip, process_split, process_caption, @@ -1110,7 +1114,9 @@ def create_ui(wrap_gradio_gpu_call): learn_rate, dataset_directory, log_directory, + training_size, steps, + num_repeats, create_image_every, save_embedding_every, template_file, -- cgit v1.2.3 From 8ec069e64df48f8f202f8b93a08e91b69448eb39 Mon Sep 17 00:00:00 2001 From: JC_Array Date: Mon, 10 Oct 2022 03:23:24 -0500 Subject: removed duplicate run_preprocess.click by creating run_preprocess_inputs list and appending deepbooru variable to input list if in scope --- modules/ui.py | 49 +++++++++++++++++-------------------------------- 1 file changed, 17 insertions(+), 32 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 179e3a83..22ca74c2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1089,40 +1089,25 @@ def create_ui(wrap_gradio_gpu_call): ] ) + run_preprocess_inputs = [ + process_src, + process_dst, + process_flip, + process_split, + process_caption, + ] if cmd_opts.deepdanbooru: # if process_caption_deepbooru is None, it will cause an error, as a result only include it if it is enabled - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - else: - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_flip, - process_split, - process_caption, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) + run_preprocess_inputs.append(process_caption_deepbooru) + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=run_preprocess_inputs, + outputs=[ + ti_output, + ti_outcome, + ], + ) train_embedding.click( fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), -- cgit v1.2.3 From 4ee7519fc2e459ce8eff1f61f1655afba393357c Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Mon, 10 Oct 2022 17:31:33 +0900 Subject: Fixed progress bar output for epoch --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e34dc2e8..769682ea 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -228,7 +228,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini optimizer.step() epoch_num = math.floor(embedding.step / epoch_len) - epoch_step = embedding.step - (epoch_num * epoch_len) + epoch_step = embedding.step - (epoch_num * epoch_len) + 1 pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}") -- cgit v1.2.3 From 2f94331df2cb1181439adecc28cfd758049f6501 Mon Sep 17 00:00:00 2001 From: JC_Array Date: Mon, 10 Oct 2022 03:34:00 -0500 Subject: removed change in last commit, simplified to adding the visible argument to process_caption_deepbooru and it set to False if deepdanbooru argument is not set --- modules/ui.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 22ca74c2..f8adafb3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1036,7 +1036,8 @@ def create_ui(wrap_gradio_gpu_call): process_caption = gr.Checkbox(label='Use BLIP caption as filename') if cmd_opts.deepdanbooru: process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename') - + else: + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename', visible=False) with gr.Row(): with gr.Column(scale=3): @@ -1089,20 +1090,17 @@ def create_ui(wrap_gradio_gpu_call): ] ) - run_preprocess_inputs = [ - process_src, - process_dst, - process_flip, - process_split, - process_caption, - ] - if cmd_opts.deepdanbooru: - # if process_caption_deepbooru is None, it will cause an error, as a result only include it if it is enabled - run_preprocess_inputs.append(process_caption_deepbooru) run_preprocess.click( fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", - inputs=run_preprocess_inputs, + inputs=[ + process_src, + process_dst, + process_flip, + process_split, + process_caption, + process_caption_deepbooru + ], outputs=[ ti_output, ti_outcome, -- cgit v1.2.3 From 23f2989799ee3911d2959cfceb74b921f20c9a51 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 10 Oct 2022 18:33:49 +0800 Subject: images history over --- modules/images_history.py | 141 +++++++++++++++++++++++++++------------------- modules/ui.py | 9 ++- 2 files changed, 92 insertions(+), 58 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 23d83557..0e0a48f3 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -1,5 +1,6 @@ import os -def get_recent_images(is_img2img, dir_name, page_index, step): +def get_recent_images(dir_name, page_index, step, image_index): + print(image_index) page_index = int(page_index) f_list = os.listdir(dir_name) file_list = [] @@ -8,7 +9,7 @@ def get_recent_images(is_img2img, dir_name, page_index, step): continue file_list.append(file) file_list = sorted(file_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) - num = 24 + num = 48 max_page_index = len(file_list) // num + 1 page_index = max_page_index if page_index == -1 else page_index + step page_index = 1 if page_index < 1 else page_index @@ -16,75 +17,101 @@ def get_recent_images(is_img2img, dir_name, page_index, step): idx_frm = (page_index - 1) * num file_list = file_list[idx_frm:idx_frm + num] print(f"Loading history page {page_index}") - return [os.path.join(dir_name, file) for file in file_list], page_index, file_list -def first_page_click(is_img2img, dir_name): - return get_recent_images(is_img2img, dir_name, 1, 0) -def end_page_click(is_img2img, dir_name): - return get_recent_images(is_img2img, dir_name, -1, 0) -def prev_page_click(is_img2img, dir_name, page_index): - return get_recent_images(is_img2img, dir_name, page_index, -1) -def next_page_click(is_img2img, dir_name, page_index): - return get_recent_images(is_img2img, dir_name, page_index, 1) -def page_index_change(is_img2img, dir_name, page_index): - return get_recent_images(is_img2img, dir_name, page_index, 0) -def show_image_info(num, filenames): - return filenames[int(num)] -def delete_image(is_img2img, dir_name, name, page_index, filenames): + image_index = int(image_index) + if image_index < 0 or image_index > len(file_list) - 1: + current_file = None + hide_image = None + else: + current_file = file_list[int(image_index)] + hide_image = os.path.join(dir_name, current_file) + return [os.path.join(dir_name, file) for file in file_list], page_index, file_list, current_file, hide_image +def first_page_click(dir_name, page_index, image_index): + return get_recent_images(dir_name, 1, 0, image_index) +def end_page_click(dir_name, page_index, image_index): + return get_recent_images(dir_name, -1, 0, image_index) +def prev_page_click(dir_name, page_index, image_index): + return get_recent_images(dir_name, page_index, -1, image_index) +def next_page_click(dir_name, page_index, image_index): + return get_recent_images(dir_name, page_index, 1, image_index) +def page_index_change(dir_name, page_index, image_index): + return get_recent_images(dir_name, page_index, 0, image_index) + +def show_image_info(num, image_path, filenames): + file = filenames[int(num)] + return file, num, os.path.join(image_path, file) +def delete_image(is_img2img, dir_name, name, page_index, filenames, image_index): + print("filename", name) path = os.path.join(dir_name, name) if os.path.exists(path): print(f"Delete file {path}") os.remove(path) - i = 0 - for f in filenames: - if f == name: - break - i += 1 - images, page_index, file_list = get_recent_images(is_img2img, dir_name, page_index, 0) - current_file = file_list[i] if i < len(file_list) else None - return images, page_index, file_list, current_file + images, page_index, file_list, current_file, hide_image = get_recent_images(dir_name, page_index, 0, image_index) + return images, page_index, file_list, current_file, hide_image -def show_images_history(gr, opts, is_img2img): +def show_images_history(gr, opts, is_img2img, run_pnginfo, switch_dict): def id_name(is_img2img, name): return ("img2img" if is_img2img else "txt2img") + "_" + name - with gr.Row(): - if is_img2img: - dir_name = opts.outdir_img2img_samples - else: - dir_name = opts.outdir_txt2img_samples - first_page = gr.Button('First Page', elem_id=id_name(is_img2img,"images_history_first_page")) - prev_page = gr.Button('Prev Page') - page_index = gr.Number(value=1) - next_page = gr.Button('Next Page') - end_page = gr.Button('End Page') - with gr.Row(): - delete = gr.Button('Delete') - Send = gr.Button('Send') - with gr.Row(): - with gr.Column(elem_id=id_name(is_img2img,"images_history")): - history_gallery = gr.Gallery(label="Images history").style(grid=6) - img_file_name = gr.Textbox() - img_file_info = gr.Textbox(dir_name) - img_path = gr.Textbox(dir_name, visible=False) - set_index = gr.Button('set_index', elem_id=id_name(is_img2img,"images_history_set_index")) - is_img2img_flag = gr.Checkbox(is_img2img, visible=False) - filenames = gr.State() - first_page.click(first_page_click, inputs=[is_img2img_flag, img_path], outputs=[history_gallery, page_index, filenames]) - next_page.click(next_page_click, inputs=[is_img2img_flag, img_path, page_index], outputs=[history_gallery, page_index, filenames]) - prev_page.click(prev_page_click, inputs=[is_img2img_flag, img_path, page_index], outputs=[history_gallery, page_index, filenames]) - end_page.click(end_page_click, inputs=[is_img2img_flag, img_path], outputs=[history_gallery, page_index, filenames]) - page_index.submit(page_index_change, inputs=[is_img2img_flag, img_path, page_index], outputs=[history_gallery, page_index, filenames]) - set_index.click(show_image_info, _js="images_history_get_current_img",inputs=[is_img2img_flag, filenames], outputs=img_file_name) - delete.click(delete_image, inputs=[is_img2img_flag, img_path, img_file_name, page_index, filenames], outputs=[history_gallery, page_index, filenames,img_file_name]) + if is_img2img: + dir_name = opts.outdir_img2img_samples + else: + dir_name = opts.outdir_txt2img_samples + with gr.Row(): + first_page = gr.Button('First', elem_id=id_name(is_img2img,"images_history_first_page")) + prev_page = gr.Button('Prev') + page_index = gr.Number(value=1, label="Page Index") + next_page = gr.Button('Next') + end_page = gr.Button('End') + with gr.Row(elem_id=id_name(is_img2img,"images_history")): + with gr.Row(): + with gr.Column(): + history_gallery = gr.Gallery(show_label=False).style(grid=6) + with gr.Column(): + with gr.Row(): + delete = gr.Button('Delete') + pnginfo_send_to_txt2img = gr.Button('Send to txt2img') + pnginfo_send_to_img2img = gr.Button('Send to img2img') + with gr.Row(): + with gr.Column(): + img_file_info = gr.Textbox(dir_name, label="Generate Info") + img_file_name = gr.Textbox(label="File Name") + with gr.Row(): + # hiden items + img_path = gr.Textbox(dir_name, visible=False) + is_img2img_flag = gr.Checkbox(is_img2img, visible=False) + image_index = gr.Textbox(value=-1, visible=False) + set_index = gr.Button('set_index', elem_id=id_name(is_img2img,"images_history_set_index")) + filenames = gr.State() + hide_image = gr.Image(visible=False, type="pil") + info1 = gr.Textbox(visible=False) + info2 = gr.Textbox(visible=False) + + + # turn pages + gallery_inputs = [img_path, page_index, image_index] + gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hide_image] + first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) #page_index.change(page_index_change, inputs=[is_img2img_flag, img_path, page_index], outputs=[history_gallery, page_index]) + + #other funcitons + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[is_img2img_flag, img_path, filenames], outputs=[img_file_name, image_index, hide_image]) + delete.click(delete_image, inputs=[is_img2img_flag, img_path, img_file_name, page_index, filenames, image_index], outputs=gallery_outputs) + hide_image.change(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) + switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') + switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') + -def create_history_tabs(gr, opts): +def create_history_tabs(gr, opts, run_pnginfo, switch_dict): with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: with gr.Tab("txt2img history", id="images_history_txt2img"): with gr.Blocks(analytics_enabled=False) as images_history_txt2img: - show_images_history(gr, opts, is_img2img=False) + show_images_history(gr, opts, False, run_pnginfo, switch_dict) with gr.Tab("img2img history", id="images_history_img2img"): with gr.Blocks(analytics_enabled=False) as images_history_img2img: - show_images_history(gr, opts, is_img2img=True) + show_images_history(gr, opts, True, run_pnginfo, switch_dict) return images_history diff --git a/modules/ui.py b/modules/ui.py index 8762fcf5..21c9236b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1131,8 +1131,15 @@ def create_ui(wrap_gradio_gpu_call): return f'{changed} settings changed.', opts.dumpjson() + #images history + images_history_switch_dict = { + "fn":modules.generation_parameters_copypaste.connect_paste, + "t2i":txt2img_paste_fields, + "i2i":img2img_paste_fields + } + images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) - images_history = img_his.create_history_tabs(gr, opts) + with gr.Blocks(analytics_enabled=False) as settings_interface: settings_submit = gr.Button(value="Apply settings", variant='primary') result = gr.HTML() -- cgit v1.2.3 From 7349088d32b080f64058b6e5de5f0380a71ecd09 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 16:11:14 +0300 Subject: --no-half-vae --- modules/devices.py | 6 +++++- modules/processing.py | 11 +++++++++-- modules/sd_models.py | 3 +++ modules/sd_samplers.py | 4 ++-- modules/shared.py | 1 + 5 files changed, 20 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 0158b11f..03ef58f1 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -36,6 +36,7 @@ errors.run(enable_tf32, "Enabling TF32") device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() dtype = torch.float16 +dtype_vae = torch.float16 def randn(seed, shape): # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. @@ -59,9 +60,12 @@ def randn_without_seed(shape): return torch.randn(shape, device=device) -def autocast(): +def autocast(disable=False): from modules import shared + if disable: + return contextlib.nullcontext() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 94d2dd62..ec8651ae 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -259,6 +259,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see return x +def decode_first_stage(model, x): + with devices.autocast(disable=x.dtype == devices.dtype_vae): + x = model.decode_first_stage(x) + + return x + + def get_fixed_seed(seed): if seed is None or seed == '' or seed == -1: return int(random.randrange(4294967294)) @@ -400,7 +407,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: samples_ddim = samples_ddim.to(devices.dtype) - x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) + x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) del samples_ddim @@ -533,7 +540,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.scale_latent: samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") else: - decoded_samples = self.sd_model.decode_first_stage(samples) + decoded_samples = decode_first_stage(self.sd_model, samples) if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") diff --git a/modules/sd_models.py b/modules/sd_models.py index e63d3c29..2cdcd84f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -149,6 +149,7 @@ def load_model_weights(model, checkpoint_info): model.half() devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" if os.path.exists(vae_file): @@ -158,6 +159,8 @@ def load_model_weights(model, checkpoint_info): model.first_stage_model.load_state_dict(vae_dict) + model.first_stage_model.to(devices.dtype_vae) + model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 6e743f7e..d168b938 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -7,7 +7,7 @@ import inspect import k_diffusion.sampling import ldm.models.diffusion.ddim import ldm.models.diffusion.plms -from modules import prompt_parser +from modules import prompt_parser, devices, processing from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -83,7 +83,7 @@ def setup_img2img_steps(p, steps=None): def sample_to_image(samples): - x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0] + x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0] x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) diff --git a/modules/shared.py b/modules/shared.py index 1995a99a..5dfc344c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -25,6 +25,7 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") +parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") -- cgit v1.2.3 From 04c745ea4f81518999927fee5f78500560c25e29 Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Mon, 10 Oct 2022 22:35:35 +0900 Subject: Custom Width and Height --- modules/textual_inversion/dataset.py | 7 +++---- modules/textual_inversion/preprocess.py | 19 ++++++++++--------- modules/textual_inversion/textual_inversion.py | 11 +++++------ modules/ui.py | 12 ++++++++---- 4 files changed, 26 insertions(+), 23 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index acc4ce59..bcf772d2 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -15,13 +15,12 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): - def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): self.placeholder_token = placeholder_token - self.size = size - self.width = size - self.height = size + self.width = width + self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.dataset = [] diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index b3de6fd7..d7efdef2 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -7,8 +7,9 @@ import tqdm from modules import shared, images -def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption): - size = process_size +def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption): + width = process_width + height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) @@ -55,23 +56,23 @@ def preprocess(process_src, process_dst, process_size, process_flip, process_spl is_wide = ratio < 1 / 1.35 if process_split and is_tall: - img = img.resize((size, size * img.height // img.width)) + img = img.resize((width, height * img.height // img.width)) - top = img.crop((0, 0, size, size)) + top = img.crop((0, 0, width, height)) save_pic(top, index) - bot = img.crop((0, img.height - size, size, img.height)) + bot = img.crop((0, img.height - height, width, img.height)) save_pic(bot, index) elif process_split and is_wide: - img = img.resize((size * img.width // img.height, size)) + img = img.resize((width * img.width // img.height, height)) - left = img.crop((0, 0, size, size)) + left = img.crop((0, 0, width, height)) save_pic(left, index) - right = img.crop((img.width - size, 0, img.width, size)) + right = img.crop((img.width - width, 0, img.width, height)) save_pic(right, index) else: - img = images.resize_image(1, img, size, size) + img = images.resize_image(1, img, width, height) save_pic(img, index) shared.state.nextjob() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 769682ea..5965c5a0 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -6,7 +6,6 @@ import torch import tqdm import html import datetime -import math from modules import shared, devices, sd_hijack, processing, sd_models @@ -157,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -183,7 +182,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -227,7 +226,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini loss.backward() optimizer.step() - epoch_num = math.floor(embedding.step / epoch_len) + epoch_num = embedding.step // epoch_len epoch_step = embedding.step - (epoch_num * epoch_len) + 1 pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}") @@ -243,8 +242,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini sd_model=shared.sd_model, prompt=text, steps=20, - height=training_size, - width=training_size, + height=training_height, + width=training_width, do_not_save_grid=True, do_not_save_samples=True, ) diff --git a/modules/ui.py b/modules/ui.py index f821fd8d..8c06ad7c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1029,7 +1029,8 @@ def create_ui(wrap_gradio_gpu_call): process_src = gr.Textbox(label='Source directory') process_dst = gr.Textbox(label='Destination directory') - process_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) + process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') @@ -1050,7 +1051,8 @@ def create_ui(wrap_gradio_gpu_call): dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) - training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) + training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) @@ -1095,7 +1097,8 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ process_src, process_dst, - process_size, + process_width, + process_height, process_flip, process_split, process_caption, @@ -1114,7 +1117,8 @@ def create_ui(wrap_gradio_gpu_call): learn_rate, dataset_directory, log_directory, - training_size, + training_width, + training_height, steps, num_repeats, create_image_every, -- cgit v1.2.3 From 8f1efdc130cf7ff47cb8d3722cdfc0dbeba3069e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 17:03:45 +0300 Subject: --no-half-vae pt2 --- modules/processing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index ec8651ae..50ba4fc5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -405,8 +405,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: # use the image collected previously in sampler loop samples_ddim = shared.state.current_latent - samples_ddim = samples_ddim.to(devices.dtype) - + samples_ddim = samples_ddim.to(devices.dtype_vae) x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) -- cgit v1.2.3 From ea00c1624bbb0dcb5be07f59c9509061baddf5b1 Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Mon, 10 Oct 2022 17:07:46 +0900 Subject: Textual Inversion: Added custom training image size and number of repeats per input image in a single epoch --- modules/textual_inversion/dataset.py | 6 +++--- modules/textual_inversion/preprocess.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 15 ++++++++++++--- modules/ui.py | 8 +++++++- 4 files changed, 24 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 7c44ea5b..acc4ce59 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -15,13 +15,13 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): - def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None): + def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): self.placeholder_token = placeholder_token self.size = size - self.width = width - self.height = height + self.width = size + self.height = size self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.dataset = [] diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index f1c002a2..b3de6fd7 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -7,8 +7,8 @@ import tqdm from modules import shared, images -def preprocess(process_src, process_dst, process_flip, process_split, process_caption): - size = 512 +def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption): + size = process_size src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index cd9f3498..e34dc2e8 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -6,6 +6,7 @@ import torch import tqdm import html import datetime +import math from modules import shared, devices, sd_hijack, processing, sd_models @@ -156,7 +157,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -182,7 +183,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -200,6 +201,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, if ititial_step > steps: return embedding, filename + tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]) + epoch_len = (tr_img_len * num_repeats) + tr_img_len + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) for i, (x, text) in pbar: embedding.step = i + ititial_step @@ -223,7 +227,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, loss.backward() optimizer.step() - pbar.set_description(f"loss: {losses.mean():.7f}") + epoch_num = math.floor(embedding.step / epoch_len) + epoch_step = embedding.step - (epoch_num * epoch_len) + + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}") if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') @@ -236,6 +243,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, sd_model=shared.sd_model, prompt=text, steps=20, + height=training_size, + width=training_size, do_not_save_grid=True, do_not_save_samples=True, ) diff --git a/modules/ui.py b/modules/ui.py index 2231a8ed..f821fd8d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1029,6 +1029,7 @@ def create_ui(wrap_gradio_gpu_call): process_src = gr.Textbox(label='Source directory') process_dst = gr.Textbox(label='Destination directory') + process_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') @@ -1043,13 +1044,15 @@ def create_ui(wrap_gradio_gpu_call): run_preprocess = gr.Button(value="Preprocess", variant='primary') with gr.Group(): - gr.HTML(value="

Train an embedding; must specify a directory with a set of 512x512 images

") + gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) learn_rate = gr.Number(label='Learning rate', value=5.0e-03) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) + training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) + num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) @@ -1092,6 +1095,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ process_src, process_dst, + process_size, process_flip, process_split, process_caption, @@ -1110,7 +1114,9 @@ def create_ui(wrap_gradio_gpu_call): learn_rate, dataset_directory, log_directory, + training_size, steps, + num_repeats, create_image_every, save_embedding_every, template_file, -- cgit v1.2.3 From 6ad3a53e368d36535de1a4fca73b3bb78fd40654 Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Mon, 10 Oct 2022 17:31:33 +0900 Subject: Fixed progress bar output for epoch --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e34dc2e8..769682ea 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -228,7 +228,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini optimizer.step() epoch_num = math.floor(embedding.step / epoch_len) - epoch_step = embedding.step - (epoch_num * epoch_len) + epoch_step = embedding.step - (epoch_num * epoch_len) + 1 pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}") -- cgit v1.2.3 From 7a20f914eddfdf09c0ccced157ec108205bc3d0f Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Mon, 10 Oct 2022 22:35:35 +0900 Subject: Custom Width and Height --- modules/textual_inversion/dataset.py | 7 +++---- modules/textual_inversion/preprocess.py | 19 ++++++++++--------- modules/textual_inversion/textual_inversion.py | 11 +++++------ modules/ui.py | 12 ++++++++---- 4 files changed, 26 insertions(+), 23 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index acc4ce59..bcf772d2 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -15,13 +15,12 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): - def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): self.placeholder_token = placeholder_token - self.size = size - self.width = size - self.height = size + self.width = width + self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.dataset = [] diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index b3de6fd7..d7efdef2 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -7,8 +7,9 @@ import tqdm from modules import shared, images -def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption): - size = process_size +def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption): + width = process_width + height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) @@ -55,23 +56,23 @@ def preprocess(process_src, process_dst, process_size, process_flip, process_spl is_wide = ratio < 1 / 1.35 if process_split and is_tall: - img = img.resize((size, size * img.height // img.width)) + img = img.resize((width, height * img.height // img.width)) - top = img.crop((0, 0, size, size)) + top = img.crop((0, 0, width, height)) save_pic(top, index) - bot = img.crop((0, img.height - size, size, img.height)) + bot = img.crop((0, img.height - height, width, img.height)) save_pic(bot, index) elif process_split and is_wide: - img = img.resize((size * img.width // img.height, size)) + img = img.resize((width * img.width // img.height, height)) - left = img.crop((0, 0, size, size)) + left = img.crop((0, 0, width, height)) save_pic(left, index) - right = img.crop((img.width - size, 0, img.width, size)) + right = img.crop((img.width - width, 0, img.width, height)) save_pic(right, index) else: - img = images.resize_image(1, img, size, size) + img = images.resize_image(1, img, width, height) save_pic(img, index) shared.state.nextjob() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 769682ea..5965c5a0 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -6,7 +6,6 @@ import torch import tqdm import html import datetime -import math from modules import shared, devices, sd_hijack, processing, sd_models @@ -157,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -183,7 +182,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -227,7 +226,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini loss.backward() optimizer.step() - epoch_num = math.floor(embedding.step / epoch_len) + epoch_num = embedding.step // epoch_len epoch_step = embedding.step - (epoch_num * epoch_len) + 1 pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}") @@ -243,8 +242,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini sd_model=shared.sd_model, prompt=text, steps=20, - height=training_size, - width=training_size, + height=training_height, + width=training_width, do_not_save_grid=True, do_not_save_samples=True, ) diff --git a/modules/ui.py b/modules/ui.py index f821fd8d..8c06ad7c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1029,7 +1029,8 @@ def create_ui(wrap_gradio_gpu_call): process_src = gr.Textbox(label='Source directory') process_dst = gr.Textbox(label='Destination directory') - process_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) + process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') @@ -1050,7 +1051,8 @@ def create_ui(wrap_gradio_gpu_call): dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) - training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512) + training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) @@ -1095,7 +1097,8 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ process_src, process_dst, - process_size, + process_width, + process_height, process_flip, process_split, process_caption, @@ -1114,7 +1117,8 @@ def create_ui(wrap_gradio_gpu_call): learn_rate, dataset_directory, log_directory, - training_size, + training_width, + training_height, steps, num_repeats, create_image_every, -- cgit v1.2.3 From 707a431100362645e914042bb344d08439f48ac8 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 15:34:49 +0100 Subject: add pixel data footer --- modules/textual_inversion/textual_inversion.py | 48 ++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 7a24192e..6fb64691 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -12,6 +12,7 @@ from ..images import captionImageOverlay import numpy as np import base64 import json +import zlib from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset @@ -20,7 +21,7 @@ class EmbeddingEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, torch.Tensor): return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()} - return json.JSONEncoder.default(self, o) + return json.JSONEncoder.default(self, obj) class EmbeddingDecoder(json.JSONDecoder): def __init__(self, *args, **kwargs): @@ -38,6 +39,45 @@ def embeddingFromB64(data): d = base64.b64decode(data) return json.loads(d,cls=EmbeddingDecoder) +def appendImageDataFooter(image,data): + d = 3 + data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9) + dnp = np.frombuffer(data_compressed,np.uint8).copy() + w = image.size[0] + next_size = dnp.shape[0] + (w-(dnp.shape[0]%w)) + next_size = next_size + ((w*d)-(next_size%(w*d))) + dnp.resize(next_size) + dnp = dnp.reshape((-1,w,d)) + print(dnp.shape) + im = Image.fromarray(dnp,mode='RGB') + background = Image.new('RGB',(image.size[0],image.size[1]+im.size[1]+1),(0,0,0)) + background.paste(image,(0,0)) + background.paste(im,(0,image.size[1]+1)) + return background + +def crop_black(img,tol=0): + mask = (img>tol).all(2) + mask0,mask1 = mask.any(0),mask.any(1) + col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax() + row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax() + return img[row_start:row_end,col_start:col_end] + +def extractImageDataFooter(image): + d=3 + outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) + lastRow = np.where( np.sum(outarr, axis=(1,2))==0) + if lastRow[0].shape[0] == 0: + print('Image data block not found.') + return None + lastRow = lastRow[0] + + lastRow = lastRow.max() + + dataBlock = outarr[lastRow+1::].astype(np.uint8).flatten().tobytes() + print(lastRow) + data = zlib.decompress(dataBlock) + return json.loads(data,cls=EmbeddingDecoder) + class Embedding: def __init__(self, vec, name, step=None): self.vec = vec @@ -113,6 +153,9 @@ class EmbeddingDatabase: if 'sd-ti-embedding' in embed_image.text: data = embeddingFromB64(embed_image.text['sd-ti-embedding']) name = data.get('name',name) + else: + data = extractImageDataFooter(embed_image) + name = data.get('name',name) else: data = torch.load(path, map_location="cpu") @@ -190,7 +233,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -308,6 +351,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini footer_right = '{}'.format(embedding.step) captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) + captioned_image = appendImageDataFooter(captioned_image,data) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) -- cgit v1.2.3 From df6d0d9286279c41c4c67460c3158fa268697524 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 15:43:09 +0100 Subject: convert back to rgb as some hosts add alpha --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6fb64691..667a7cf2 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -64,7 +64,7 @@ def crop_black(img,tol=0): def extractImageDataFooter(image): d=3 - outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) + outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) lastRow = np.where( np.sum(outarr, axis=(1,2))==0) if lastRow[0].shape[0] == 0: print('Image data block not found.') -- cgit v1.2.3 From f347ddfd808c56bb1bacdec0c4bedf826ff85cd8 Mon Sep 17 00:00:00 2001 From: RW21 Date: Mon, 10 Oct 2022 10:44:11 +0900 Subject: Remove max_batch_count from ui.py --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 8c06ad7c..8ba84911 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -524,7 +524,7 @@ def create_ui(wrap_gradio_gpu_call): denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) with gr.Row(): - batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) @@ -710,7 +710,7 @@ def create_ui(wrap_gradio_gpu_call): tiling = gr.Checkbox(label='Tiling', value=False) with gr.Row(): - batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) with gr.Group(): -- cgit v1.2.3 From b340439586d844e76782149ca1857c8de35773ec Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Mon, 10 Oct 2022 05:28:06 +0100 Subject: Unlimited Token Works Unlimited tokens actually work now. Works with textual inversion too. Replaces the previous not-so-much-working implementation. --- modules/sd_hijack.py | 69 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 23 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 437acce4..8d5c77d8 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -43,10 +43,7 @@ def undo_optimizations(): def get_target_prompt_token_count(token_count): - if token_count < 75: - return 75 - - return math.ceil(token_count / 10) * 10 + return math.ceil(max(token_count, 1) / 75) * 75 class StableDiffusionModelHijack: @@ -127,7 +124,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.token_mults[ident] = mult def tokenize_line(self, line, used_custom_terms, hijack_comments): - id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id if opts.enable_emphasis: @@ -154,7 +150,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): i += 1 else: emb_len = int(embedding.vec.shape[0]) - fixes.append((len(remade_tokens), embedding)) + iteration = len(remade_tokens) // 75 + fixes.append((iteration, (len(remade_tokens) % 75, embedding))) remade_tokens += [0] * emb_len multipliers += [weight] * emb_len used_custom_terms.append((embedding.name, embedding.checksum())) @@ -162,10 +159,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): token_count = len(remade_tokens) prompt_target_length = get_target_prompt_token_count(token_count) - tokens_to_add = prompt_target_length - len(remade_tokens) + 1 + tokens_to_add = prompt_target_length - len(remade_tokens) - remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add - multipliers = [1.0] + multipliers + [1.0] * tokens_to_add + remade_tokens = remade_tokens + [id_end] * tokens_to_add + multipliers = multipliers + [1.0] * tokens_to_add return remade_tokens, fixes, multipliers, token_count @@ -260,29 +257,55 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): hijack_fixes.append(fixes) batch_multipliers.append(multipliers) return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - + def forward(self, text): - - if opts.use_old_emphasis_implementation: + use_old = opts.use_old_emphasis_implementation + if use_old: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - self.hijack.fixes = hijack_fixes self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + + if use_old: + self.hijack.fixes = hijack_fixes + return self.process_tokens(remade_batch_tokens, batch_multipliers) + + z = None + i = 0 + while max(map(len, remade_batch_tokens)) != 0: + rem_tokens = [x[75:] for x in remade_batch_tokens] + rem_multipliers = [x[75:] for x in batch_multipliers] + + self.hijack.fixes = [] + for unfiltered in hijack_fixes: + fixes = [] + for fix in unfiltered: + if fix[0] == i: + fixes.append(fix[1]) + self.hijack.fixes.append(fixes) + + z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers]) + z = z1 if z is None else torch.cat((z, z1), axis=-2) + + remade_batch_tokens = rem_tokens + batch_multipliers = rem_multipliers + i += 1 + + return z + + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] + batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] + + tokens = torch.asarray(remade_batch_tokens).to(device) + outputs = self.wrapped.transformer(input_ids=tokens) - target_token_count = get_target_prompt_token_count(token_count) + 2 - - position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76] - position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) - - remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] - tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers) if opts.CLIP_stop_at_last_layers > 1: z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] z = self.wrapped.transformer.text_model.final_layer_norm(z) @@ -290,7 +313,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] + batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) original_mean = z.mean() z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) -- cgit v1.2.3 From 460bbae58726c177beddfcddf351f27e205d3fb2 Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Mon, 10 Oct 2022 16:09:06 +0100 Subject: Pad beginning of textual inversion embedding --- modules/sd_hijack.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 8d5c77d8..3a60cd63 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -151,6 +151,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: emb_len = int(embedding.vec.shape[0]) iteration = len(remade_tokens) // 75 + if (len(remade_tokens) + emb_len) // 75 != iteration: + rem = (75 * (iteration + 1) - len(remade_tokens)) + remade_tokens += [id_end] * rem + multipliers += [1.0] * rem + iteration += 1 fixes.append((iteration, (len(remade_tokens) % 75, embedding))) remade_tokens += [0] * emb_len multipliers += [weight] * emb_len -- cgit v1.2.3 From d5c14365fd468dbf89fa12a68bea5b217077273c Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Mon, 10 Oct 2022 16:13:47 +0100 Subject: Add back in output hidden states parameter --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3a60cd63..3edc0e9d 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -309,7 +309,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] tokens = torch.asarray(remade_batch_tokens).to(device) - outputs = self.wrapped.transformer(input_ids=tokens) + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) if opts.CLIP_stop_at_last_layers > 1: z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] -- cgit v1.2.3 From 9d33baba587637815d818e5e641d8f8b74c4900d Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 10 Oct 2022 18:46:48 +0300 Subject: Always show previous mask and fix extras_send dest --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 8ba84911..e8039d76 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -961,7 +961,7 @@ def create_ui(wrap_gradio_gpu_call): extras_send_to_inpaint.click( fn=lambda x: image_from_url_text(x), - _js="extract_image_from_gallery_img2img", + _js="extract_image_from_gallery_inpaint", inputs=[result_images], outputs=[init_img_with_mask], ) -- cgit v1.2.3 From 623251ce2b8d152e242011f62984a8247a14a389 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 10 Oct 2022 17:45:38 +0300 Subject: allow pascal onwards --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3edc0e9d..827bf304 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -23,7 +23,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward -- cgit v1.2.3 From 3e7a981194ed9c454e951365846e4eba66fa7095 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 10 Oct 2022 17:51:05 +0300 Subject: remove functorch --- modules/sd_hijack_optimizations.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 634fb4b2..18408e62 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -13,8 +13,6 @@ from modules import shared if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: import xformers.ops - import functorch - xformers._is_functorch_available = True shared.xformers_available = True except Exception: print("Cannot import xformers", file=sys.stderr) -- cgit v1.2.3 From ece27fe98933eb0eda8ea94dc496dd7554f3a08f Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sun, 9 Oct 2022 18:55:33 +0300 Subject: Add files via upload --- modules/swinir_model_arch_v2.py | 1017 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 1017 insertions(+) create mode 100644 modules/swinir_model_arch_v2.py (limited to 'modules') diff --git a/modules/swinir_model_arch_v2.py b/modules/swinir_model_arch_v2.py new file mode 100644 index 00000000..0e28ae6e --- /dev/null +++ b/modules/swinir_model_arch_v2.py @@ -0,0 +1,1017 @@ +# ----------------------------------------------------------------------------------- +# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/ +# Written by Conde and Choi et al. +# ----------------------------------------------------------------------------------- + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., + pretrained_window_size=[0, 0]): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False)) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack( + torch.meshgrid([relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + # cosine attention + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, ' \ + f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (int): Window size in pre-training. + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size)) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + #assert L == H * W, "input feature has wrong size" + + shortcut = x + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + x = shortcut + self.drop_path(self.norm1(x)) + + # FFN + x = x + self.drop_path(self.norm2(self.mlp(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + x = self.norm(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + flops += H * W * self.dim // 2 + return flops + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + pretrained_window_size (int): Local window size in pre-training. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + pretrained_window_size=0): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + +class Upsample_hf(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample_hf, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + + +class Swin2SR(nn.Module): + r""" Swin2SR + A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + **kwargs): + super(Swin2SR, self).__init__() + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers.append(layer) + + if self.upsampler == 'pixelshuffle_hf': + self.layers_hf = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + + ) + self.layers_hf.append(layer) + + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffle_aux': + self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_after_aux = nn.Sequential( + nn.Conv2d(3, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffle_hf': + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.upsample_hf = Upsample_hf(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_before_upsample_hf = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward_features_hf(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers_hf: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffle_aux': + bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) + bicubic = self.conv_bicubic(bicubic) + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + aux = self.conv_aux(x) # b, 3, LR_H, LR_W + x = self.conv_after_aux(aux) + x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] + x = self.conv_last(x) + aux = aux / self.img_range + self.mean + elif self.upsampler == 'pixelshuffle_hf': + # for classical SR with HF + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x_before = self.conv_before_upsample(x) + x_out = self.conv_last(self.upsample(x_before)) + + x_hf = self.conv_first_hf(x_before) + x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf + x_hf = self.conv_before_upsample_hf(x_hf) + x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) + x = x_out + x_hf + x_hf = x_hf / self.img_range + self.mean + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + if self.upsampler == "pixelshuffle_aux": + return x[:, :, :H*self.upscale, :W*self.upscale], aux + + elif self.upsampler == "pixelshuffle_hf": + x_out = x_out / self.img_range + self.mean + return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] + + else: + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = Swin2SR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + x = model(x) + print(x.shape) \ No newline at end of file -- cgit v1.2.3 From ed769977f0d0f201d8e361d365102f18775fc62c Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sun, 9 Oct 2022 18:56:59 +0300 Subject: add swinir v2 support --- modules/swinir_model.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/swinir_model.py b/modules/swinir_model.py index fbd11f84..baa02e3d 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -10,6 +10,7 @@ from tqdm import tqdm from modules import modelloader from modules.shared import cmd_opts, opts, device from modules.swinir_model_arch import SwinIR as net +from modules.swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData precision_scope = ( @@ -57,22 +58,42 @@ class UpscalerSwinIR(Upscaler): filename = path if filename is None or not os.path.exists(filename): return None - model = net( + if filename.endswith(".v2.pth"): + model = net2( upscale=scale, in_chans=3, img_size=64, window_size=8, img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2, upsampler="nearest+conv", - resi_connection="3conv", - ) + resi_connection="1conv", + ) + params = None + else: + model = net( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], + embed_dim=240, + num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="3conv", + ) + params = "params_ema" pretrained_model = torch.load(filename) - model.load_state_dict(pretrained_model["params_ema"], strict=True) + if params is not None: + model.load_state_dict(pretrained_model[params], strict=True) + else: + model.load_state_dict(pretrained_model, strict=True) if not cmd_opts.no_half: model = model.half() return model -- cgit v1.2.3 From af62ad4d25dcd0454944368f4925d83101cdedbc Mon Sep 17 00:00:00 2001 From: ssysm Date: Mon, 10 Oct 2022 13:25:28 -0400 Subject: change vae loading method --- modules/sd_models.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index b0e1d8bd..7a42d924 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -150,9 +150,16 @@ def load_model_weights(model, checkpoint_info): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - vae_file = shared.cmd_opts.vae_path or os.path.splitext(checkpoint_file)[0] + ".vae.pt" + vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + if os.path.exists(vae_file): + print(f"Found VAE Weights: {vae_file}") + elif shared.cmd_opts.vae_path != None: + vae_file = shared.cmd_opts.vae_path + print(f'No VAE found for inside the model folder. Using CLI specified : {vae_file}') + else: + print("No VAE found for inside the model folder. Passing.") + if os.path.exists(vae_file): - print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location="cpu") vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} -- cgit v1.2.3 From 39919c40dd18f5a14ae21403efea1b0f819756c7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 20:32:37 +0300 Subject: add eta noise seed delta option --- modules/processing.py | 6 +++++- modules/shared.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 50ba4fc5..698b3069 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -207,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see # enables the generation of additional tensors with noise that the sampler will use during its processing. # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # produce the same images as with two batches [100], [101]. - if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds: + if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] else: sampler_noises = None @@ -247,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see if sampler_noises is not None: cnt = p.sampler.number_of_needed_noises(p) + if opts.eta_noise_seed_delta > 0: + torch.manual_seed(seed + opts.eta_noise_seed_delta) + for j in range(cnt): sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape))) @@ -301,6 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, + "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, } generation_params.update(p.extra_generation_params) diff --git a/modules/shared.py b/modules/shared.py index 5dfc344c..b1c65ecf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -260,6 +260,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) -- cgit v1.2.3 From 727e4d108674dc2813507e2a973a733ef21e8d53 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 20:46:55 +0300 Subject: no to different messages plus fix using != to compare to None --- modules/sd_models.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 4c06051e..0a55b4c3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -152,15 +152,12 @@ def load_model_weights(model, checkpoint_info): devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" - if os.path.exists(vae_file): - print(f"Found VAE Weights: {vae_file}") - elif shared.cmd_opts.vae_path != None: + + if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: vae_file = shared.cmd_opts.vae_path - print(f'No VAE found for inside the model folder. Using CLI specified : {vae_file}') - else: - print("No VAE found for inside the model folder. Passing.") if os.path.exists(vae_file): + print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location="cpu") vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} -- cgit v1.2.3 From 1d64976dbc5a0f3124567b91fadd5014a9d93c5f Mon Sep 17 00:00:00 2001 From: Justin Maier Date: Mon, 10 Oct 2022 12:04:21 -0600 Subject: Simplify crop logic --- modules/extras.py | 14 +++----------- modules/ui.py | 4 ++-- 2 files changed, 5 insertions(+), 13 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 83ca7049..b24d7de3 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -73,16 +73,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, crop_info = " (crop)" if upscaling_crop else "" info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - def crop_upscaled_center(image, resize_w, resize_h): - left = int(math.ceil((image.width - resize_w) / 2)) - right = image.width - int(math.floor((image.width - resize_w) / 2)) - top = int(math.ceil((image.height - resize_h) / 2)) - bottom = image.height - int(math.floor((image.height - resize_h) / 2)) - - image = image.crop((left, top, right, bottom)) - return image - - if upscaling_resize != 1.0: def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) @@ -94,7 +84,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, upscaler = shared.sd_upscalers[scaler_index] c = upscaler.scaler.upscale(image, resize, upscaler.data_path) if mode == 1 and crop: - c = crop_upscaled_center(c, resize_w, resize_h) + cropped = Image.new("RGB", (resize_w, resize_h)) + cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2)) + c = cropped cached_images[key] = c return c diff --git a/modules/ui.py b/modules/ui.py index 4bb2892b..1aabe18d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -909,8 +909,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.TabItem('Scale to'): with gr.Group(): with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512) - upscaling_resize_h = gr.Number(label="Height", value=512) + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) with gr.Group(): -- cgit v1.2.3 From bc3e183b739913e7be91213a256f038b10eb71e9 Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Tue, 11 Oct 2022 04:30:13 +0900 Subject: Textual Inversion: Preprocess and Training will only pick-up image files --- modules/textual_inversion/dataset.py | 3 ++- modules/textual_inversion/preprocess.py | 3 ++- modules/textual_inversion/textual_inversion.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index bcf772d2..d4baf066 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -22,6 +22,7 @@ class PersonalizedBase(Dataset): self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) + self.extns = [".jpg",".jpeg",".png"] self.dataset = [] @@ -32,7 +33,7 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' - self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in self.extns] print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): image = Image.open(path) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index d7efdef2..b6c78cf8 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,12 +12,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) + extns = [".jpg",".jpeg",".png"] assert src != dst, 'same directory specified as source and destination' os.makedirs(dst, exist_ok=True) - files = os.listdir(src) + files = [i for i in os.listdir(src) if os.path.splitext(i.casefold())[1] in extns] shared.state.textinfo = "Preprocessing..." shared.state.job_count = len(files) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5965c5a0..45397be9 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -161,6 +161,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = "Initializing textual inversion training..." shared.state.job_count = steps + extns = [".jpg",".jpeg",".png"] filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') @@ -200,7 +201,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if ititial_step > steps: return embedding, filename - tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]) + tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in extns]) epoch_len = (tr_img_len * num_repeats) + tr_img_len pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) -- cgit v1.2.3 From f98338faa84ecce503e68d8ba13d5f7bbae52730 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 23:15:48 +0300 Subject: add an option to not add watermark to created images --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index da389f9c..ecd15ef5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -173,6 +173,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), + "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { -- cgit v1.2.3 From 2536ecbb1790da2af0d61b6a26f38732cba665cd Mon Sep 17 00:00:00 2001 From: Fampai <> Date: Mon, 10 Oct 2022 17:10:29 -0400 Subject: Refactored learning rate code --- modules/textual_inversion/textual_inversion.py | 51 ++++++++++++++++++++++++-- modules/ui.py | 2 +- 2 files changed, 48 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5965c5a0..c64a4598 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -189,8 +189,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini embedding = hijack.embedding_db.word_embeddings[embedding_name] embedding.vec.requires_grad = True - optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) - losses = torch.zeros((32,)) last_saved_file = "" @@ -203,12 +201,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]) epoch_len = (tr_img_len * num_repeats) + tr_img_len + scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step)) + (learn_rate, end_step) = next(scheduleIter) + print(f'Training at rate of {learn_rate} until step {end_step}') + + optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) for i, (x, text) in pbar: embedding.step = i + ititial_step - if embedding.step > steps: - break + if embedding.step > end_step: + try: + (learn_rate, end_step) = next(scheduleIter) + except: + break + tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}') + for pg in optimizer.param_groups: + pg['lr'] = learn_rate if shared.state.interrupted: break @@ -277,3 +287,36 @@ Last saved image: {html.escape(last_saved_image)}
return embedding, filename +class LearnSchedule: + def __init__(self, learn_rate, max_steps, cur_step=0): + pairs = learn_rate.split(',') + self.rates = [] + self.it = 0 + self.maxit = 0 + for i, pair in enumerate(pairs): + tmp = pair.split(':') + if len(tmp) == 2: + step = int(tmp[1]) + if step > cur_step: + self.rates.append((float(tmp[0]), min(step, max_steps))) + self.maxit += 1 + if step > max_steps: + return + elif step == -1: + self.rates.append((float(tmp[0]), max_steps)) + self.maxit += 1 + return + else: + self.rates.append((float(tmp[0]), max_steps)) + self.maxit += 1 + return + + def __iter__(self): + return self + + def __next__(self): + if self.it < self.maxit: + self.it += 1 + return self.rates[self.it - 1] + else: + raise StopIteration diff --git a/modules/ui.py b/modules/ui.py index 8c06ad7c..c9e8355b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1047,7 +1047,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - learn_rate = gr.Number(label='Learning rate', value=5.0e-03) + learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value = "5.0e-03") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) -- cgit v1.2.3 From 907a88b2d0be320575c2129d8d6a1d4f3a68f9eb Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Tue, 11 Oct 2022 06:33:08 +0900 Subject: Added .webp .bmp --- modules/textual_inversion/dataset.py | 2 +- modules/textual_inversion/preprocess.py | 2 +- modules/textual_inversion/textual_inversion.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index d4baf066..0dc54fb7 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -22,7 +22,7 @@ class PersonalizedBase(Dataset): self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) - self.extns = [".jpg",".jpeg",".png"] + self.extns = [".jpg",".jpeg",".png",".webp",".bmp"] self.dataset = [] diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index b6c78cf8..8290abe8 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,7 +12,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) - extns = [".jpg",".jpeg",".png"] + extns = [".jpg",".jpeg",".png",".webp",".bmp"] assert src != dst, 'same directory specified as source and destination' diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index a03b299c..33c923d1 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -161,7 +161,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = "Initializing textual inversion training..." shared.state.job_count = steps - extns = [".jpg",".jpeg",".png"] + extns = [".jpg",".jpeg",".png",".webp",".bmp"] filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') -- cgit v1.2.3 From a1a05ad2d13d0b995dbf8ecead6315f17837ef81 Mon Sep 17 00:00:00 2001 From: JC_Array Date: Mon, 10 Oct 2022 16:47:58 -0500 Subject: import time missing, added to deepbooru fixxing error on get_deepbooru_tags --- modules/deepbooru.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index cee4a3b4..12555b2e 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -1,6 +1,7 @@ import os.path from concurrent.futures import ProcessPoolExecutor import multiprocessing +import time def get_deepbooru_tags(pil_image, threshold=0.5): -- cgit v1.2.3 From b980e7188c671fc55b26557f097076fb5c976ba0 Mon Sep 17 00:00:00 2001 From: JC_Array Date: Mon, 10 Oct 2022 16:52:54 -0500 Subject: corrected tag return in get_deepbooru_tags --- modules/deepbooru.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 12555b2e..ebdba5e0 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -15,7 +15,6 @@ def get_deepbooru_tags(pil_image, threshold=0.5): while shared.deepbooru_process_return["value"] == -1: time.sleep(0.2) release_process() - return ret def deepbooru_process(queue, deepbooru_process_return, threshold): -- cgit v1.2.3 From 315d5a8ed975c88f670bc484f40a23fbf3a77b63 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 23:14:44 +0100 Subject: update data dis[play style --- modules/textual_inversion/textual_inversion.py | 88 +++++++++++++++++++------- 1 file changed, 65 insertions(+), 23 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 667a7cf2..95eebea7 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -39,20 +39,59 @@ def embeddingFromB64(data): d = base64.b64decode(data) return json.loads(d,cls=EmbeddingDecoder) -def appendImageDataFooter(image,data): +def xorBlock(block): + return np.bitwise_xor(block.astype(np.uint8), + ((np.random.RandomState(0xDEADBEEF).random(block.shape)*255).astype(np.uint8)) & 0x0F ) + +def styleBlock(block,sequence): + im = Image.new('RGB',(block.shape[1],block.shape[0])) + draw = ImageDraw.Draw(im) + i=0 + for x in range(-6,im.size[0],8): + for yi,y in enumerate(range(-6,im.size[1],8)): + offset=0 + if yi%2==0: + offset=4 + shade = sequence[i%len(sequence)] + i+=1 + draw.ellipse((x+offset, y, x+6+offset, y+6), fill =(shade,shade,shade) ) + + fg = np.array(im).astype(np.uint8) & 0xF0 + return block ^ fg + +def insertImageDataEmbed(image,data): d = 3 data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9) dnp = np.frombuffer(data_compressed,np.uint8).copy() - w = image.size[0] - next_size = dnp.shape[0] + (w-(dnp.shape[0]%w)) - next_size = next_size + ((w*d)-(next_size%(w*d))) - dnp.resize(next_size) - dnp = dnp.reshape((-1,w,d)) - print(dnp.shape) - im = Image.fromarray(dnp,mode='RGB') - background = Image.new('RGB',(image.size[0],image.size[1]+im.size[1]+1),(0,0,0)) - background.paste(image,(0,0)) - background.paste(im,(0,image.size[1]+1)) + dnphigh = dnp >> 4 + dnplow = dnp & 0x0F + + h = image.size[1] + next_size = dnplow.shape[0] + (h-(dnplow.shape[0]%h)) + next_size = next_size + ((h*d)-(next_size%(h*d))) + + dnplow.resize(next_size) + dnplow = dnplow.reshape((h,-1,d)) + + dnphigh.resize(next_size) + dnphigh = dnphigh.reshape((h,-1,d)) + + edgeStyleWeights = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] + edgeStyleWeights = (np.abs(edgeStyleWeights)/np.max(np.abs(edgeStyleWeights))*255).astype(np.uint8) + + dnplow = styleBlock(dnplow,sequence=edgeStyleWeights) + dnplow = xorBlock(dnplow) + dnphigh = styleBlock(dnphigh,sequence=edgeStyleWeights[::-1]) + dnphigh = xorBlock(dnphigh) + + imlow = Image.fromarray(dnplow,mode='RGB') + imhigh = Image.fromarray(dnphigh,mode='RGB') + + background = Image.new('RGB',(image.size[0]+imlow.size[0]+imhigh.size[0]+2,image.size[1]),(0,0,0)) + background.paste(imlow,(0,0)) + background.paste(image,(imlow.size[0]+1,0)) + background.paste(imhigh,(imlow.size[0]+1+image.size[0]+1,0)) + return background def crop_black(img,tol=0): @@ -62,19 +101,22 @@ def crop_black(img,tol=0): row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax() return img[row_start:row_end,col_start:col_end] -def extractImageDataFooter(image): +def extractImageDataEmbed(image): d=3 - outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) - lastRow = np.where( np.sum(outarr, axis=(1,2))==0) - if lastRow[0].shape[0] == 0: - print('Image data block not found.') + outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F + blackCols = np.where( np.sum(outarr, axis=(0,2))==0) + if blackCols[0].shape[0] < 2: + print('No Image data blocks found.') return None - lastRow = lastRow[0] - - lastRow = lastRow.max() - dataBlock = outarr[lastRow+1::].astype(np.uint8).flatten().tobytes() - print(lastRow) + dataBlocklower = outarr[:,:blackCols[0].min(),:].astype(np.uint8) + dataBlockupper = outarr[:,blackCols[0].max()+1:,:].astype(np.uint8) + + dataBlocklower = xorBlock(dataBlocklower) + dataBlockupper = xorBlock(dataBlockupper) + + dataBlock = (dataBlockupper << 4) | (dataBlocklower) + dataBlock = dataBlock.flatten().tobytes() data = zlib.decompress(dataBlock) return json.loads(data,cls=EmbeddingDecoder) @@ -154,7 +196,7 @@ class EmbeddingDatabase: data = embeddingFromB64(embed_image.text['sd-ti-embedding']) name = data.get('name',name) else: - data = extractImageDataFooter(embed_image) + data = extractImageDataEmbed(embed_image) name = data.get('name',name) else: data = torch.load(path, map_location="cpu") @@ -351,7 +393,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini footer_right = '{}'.format(embedding.step) captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) - captioned_image = appendImageDataFooter(captioned_image,data) + captioned_image = insertImageDataEmbed(captioned_image,data) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) -- cgit v1.2.3 From 767202a4c324f9b49f63ab4dabbb5736fe9df6e5 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 23:20:52 +0100 Subject: add dependency --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 95eebea7..f3cacaa0 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,7 +7,7 @@ import tqdm import html import datetime -from PIL import Image,PngImagePlugin +from PIL import Image,PngImagePlugin,ImageDraw from ..images import captionImageOverlay import numpy as np import base64 -- cgit v1.2.3 From e0fbe6d27e7b4505766c8cb5a4264e1114cf3721 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 10 Oct 2022 23:26:24 +0100 Subject: colour depth conversion fix --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index f3cacaa0..ae807268 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -103,7 +103,7 @@ def crop_black(img,tol=0): def extractImageDataEmbed(image): d=3 - outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F + outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F blackCols = np.where( np.sum(outarr, axis=(0,2))==0) if blackCols[0].shape[0] < 2: print('No Image data blocks found.') -- cgit v1.2.3 From 76ef3d75f61253516c024553335d9083d9660a8a Mon Sep 17 00:00:00 2001 From: JC_Array Date: Mon, 10 Oct 2022 18:01:49 -0500 Subject: added deepbooru settings (threshold and sort by alpha or likelyhood) --- modules/deepbooru.py | 36 +++++++++++++++++++++++++----------- modules/shared.py | 6 ++++++ 2 files changed, 31 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index ebdba5e0..e31e92c0 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -3,31 +3,32 @@ from concurrent.futures import ProcessPoolExecutor import multiprocessing import time - -def get_deepbooru_tags(pil_image, threshold=0.5): +def get_deepbooru_tags(pil_image): """ This method is for running only one image at a time for simple use. Used to the img2img interrogate. """ from modules import shared # prevents circular reference - create_deepbooru_process(threshold) + create_deepbooru_process(shared.opts.deepbooru_threshold, shared.opts.deepbooru_sort_alpha) shared.deepbooru_process_return["value"] = -1 shared.deepbooru_process_queue.put(pil_image) while shared.deepbooru_process_return["value"] == -1: time.sleep(0.2) + tags = shared.deepbooru_process_return["value"] release_process() + return tags -def deepbooru_process(queue, deepbooru_process_return, threshold): +def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort): model, tags = get_deepbooru_tags_model() while True: # while process is running, keep monitoring queue for new image pil_image = queue.get() if pil_image == "QUIT": break else: - deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold) + deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort) -def create_deepbooru_process(threshold=0.5): +def create_deepbooru_process(threshold, alpha_sort): """ Creates deepbooru process. A queue is created to send images into the process. This enables multiple images to be processed in a row without reloading the model or creating a new process. To return the data, a shared @@ -40,7 +41,7 @@ def create_deepbooru_process(threshold=0.5): shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold)) + shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort)) shared.deepbooru_process.start() @@ -80,7 +81,7 @@ def get_deepbooru_tags_model(): return model, tags -def get_deepbooru_tags_from_model(model, tags, pil_image, threshold=0.5): +def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort): import deepdanbooru as dd import tensorflow as tf import numpy as np @@ -105,15 +106,28 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold=0.5): for i, tag in enumerate(tags): result_dict[tag] = y[i] - result_tags_out = [] + + unsorted_tags_in_theshold = [] result_tags_print = [] for tag in tags: if result_dict[tag] >= threshold: if tag.startswith("rating:"): continue - result_tags_out.append(tag) + unsorted_tags_in_theshold.append((result_dict[tag], tag)) result_tags_print.append(f'{result_dict[tag]} {tag}') + # sort tags + result_tags_out = [] + sort_ndx = 0 + print(alpha_sort) + if alpha_sort: + sort_ndx = 1 + + # sort by reverse by likelihood and normal for alpha + unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) + for weight, tag in unsorted_tags_in_theshold: + result_tags_out.append(tag) + print('\n'.join(sorted(result_tags_print, reverse=True))) - return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') \ No newline at end of file + return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') diff --git a/modules/shared.py b/modules/shared.py index 1995a99a..2e307809 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -261,6 +261,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) +if cmd_opts.deepdanbooru: + options_templates.update(options_section(('deepbooru-params', "DeepBooru parameters"), { + "deepbooru_sort_alpha": OptionInfo(True, "Sort Alphabetical", gr.Checkbox), + 'deepbooru_threshold': OptionInfo(0.5, "Threshold", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + })) + class Options: data = None -- cgit v1.2.3 From bb932dbf9faf43ba918daa4791873078797b2a48 Mon Sep 17 00:00:00 2001 From: JC_Array Date: Mon, 10 Oct 2022 18:37:52 -0500 Subject: added alpha sort and threshold variables to create process method in preprocessing --- modules/textual_inversion/preprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 4a2194da..c0af729b 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -29,7 +29,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ shared.interrogator.load() if process_caption_deepbooru: - deepbooru.create_deepbooru_process() + deepbooru.create_deepbooru_process(opts.deepbooru_threshold, opts.deepbooru_sort_alpha) def save_pic_with_caption(image, index): if process_caption: -- cgit v1.2.3 From 1add3cff84b7e2436d69b1e97ae689281e4a7c33 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Mon, 10 Oct 2022 19:57:43 -0500 Subject: Refresh list of models/ckpts upon hitting restart gradio in the settings pane --- modules/ui.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index e8039d76..06ff118f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -39,6 +39,7 @@ import modules.generation_parameters_copypaste from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui +from modules.sd_models import list_models # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -1290,6 +1291,9 @@ Requested path was: {f} shared.state.interrupt() settings_interface.gradio_ref.do_restart = True + # refresh models so that new models/.ckpt's show up on reload + list_models() + restart_gradio.click( fn=request_restart, inputs=[], -- cgit v1.2.3 From 7aa8fcac1e45c3ad9c6a40df0e44a346afcd5032 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 04:17:36 +0100 Subject: use simple lcg in xor --- modules/textual_inversion/textual_inversion.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ae807268..13416a08 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -39,9 +39,15 @@ def embeddingFromB64(data): d = base64.b64decode(data) return json.loads(d,cls=EmbeddingDecoder) +def lcg(m=2**32, a=1664525, c=1013904223, seed=0): + while True: + seed = (a * seed + c) % m + yield seed + def xorBlock(block): - return np.bitwise_xor(block.astype(np.uint8), - ((np.random.RandomState(0xDEADBEEF).random(block.shape)*255).astype(np.uint8)) & 0x0F ) + g = lcg() + randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape) + return np.bitwise_xor(block.astype(np.uint8),randblock & 0x0F) def styleBlock(block,sequence): im = Image.new('RGB',(block.shape[1],block.shape[0])) -- cgit v1.2.3 From 8b7d3f1bef47bbe048f644ed0d8dd3ad46554045 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Tue, 11 Oct 2022 02:22:46 -0300 Subject: Make the ctrl+enter shortcut use the generate button on the current tab --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index e8039d76..cafda884 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1331,7 +1331,7 @@ Requested path was: {f} with gr.Tabs() as tabs: for interface, label, ifid in interfaces: - with gr.TabItem(label, id=ifid): + with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): interface.render() if os.path.exists(os.path.join(script_path, "notification.mp3")): -- cgit v1.2.3 From 8617396c6df71074c7fd3d39419802026874712a Mon Sep 17 00:00:00 2001 From: Kenneth Date: Mon, 10 Oct 2022 17:23:07 -0600 Subject: Added slider for deepbooru score threshold in settings --- modules/shared.py | 1 + modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index ecd15ef5..e0830e28 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -239,6 +239,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"), + "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), })) options_templates.update(options_section(('ui', "User interface"), { diff --git a/modules/ui.py b/modules/ui.py index cafda884..ca3151c4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -311,7 +311,7 @@ def interrogate(image): def interrogate_deepbooru(image): - prompt = get_deepbooru_tags(image) + prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold) return gr_show(True) if prompt is None else prompt -- cgit v1.2.3 From 5e2627a1a63e4c9f87e6e604ecc24e9936f149de Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Tue, 11 Oct 2022 07:55:28 +0100 Subject: Comma backtrack padding (#2192) Comma backtrack padding --- modules/sd_hijack.py | 19 ++++++++++++++++++- modules/shared.py | 1 + 2 files changed, 19 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 827bf304..aa4d2cbc 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -107,6 +107,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.tokenizer = wrapped.tokenizer self.token_mults = {} + self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 @@ -136,6 +138,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): fixes = [] remade_tokens = [] multipliers = [] + last_comma = -1 for tokens, (text, weight) in zip(tokenized, parsed): i = 0 @@ -144,6 +147,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + if token == self.comma_token: + last_comma = len(remade_tokens) + elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: + last_comma += 1 + reloc_tokens = remade_tokens[last_comma:] + reloc_mults = multipliers[last_comma:] + + remade_tokens = remade_tokens[:last_comma] + length = len(remade_tokens) + + rem = int(math.ceil(length / 75)) * 75 - length + remade_tokens += [id_end] * rem + reloc_tokens + multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults + if embedding is None: remade_tokens.append(token) multipliers.append(weight) @@ -284,7 +301,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while max(map(len, remade_batch_tokens)) != 0: rem_tokens = [x[75:] for x in remade_batch_tokens] rem_multipliers = [x[75:] for x in batch_multipliers] - + self.hijack.fixes = [] for unfiltered in hijack_fixes: fixes = [] diff --git a/modules/shared.py b/modules/shared.py index e0830e28..14b40d70 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -227,6 +227,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), + "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), -- cgit v1.2.3 From 948533950c9db5069a874d925fadd50bac00fdb5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 11:09:51 +0300 Subject: replace duplicate code with a function --- modules/hypernetwork.py | 23 ++++++++++++-------- modules/sd_hijack_optimizations.py | 44 +++++++++++++------------------------- 2 files changed, 29 insertions(+), 38 deletions(-) (limited to 'modules') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 498bc9d8..7bbc443e 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -64,21 +64,26 @@ def load_hypernetwork(filename): shared.loaded_hypernetwork = None +def apply_hypernetwork(hypernetwork, context): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is None: + return context, context + + context_k = hypernetwork_layers[0](context) + context_v = hypernetwork_layers[1](context) + return context_k, context_v + + def attention_CrossAttention_forward(self, x, context=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k = self.to_k(hypernetwork_layers[0](context)) - v = self.to_v(hypernetwork_layers[1](context)) - else: - k = self.to_k(context) - v = self.to_v(context) + context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context) + k = self.to_k(context_k) + v = self.to_v(context_v) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 18408e62..25cb67a4 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -8,7 +8,8 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared +from modules import shared, hypernetwork + if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: @@ -26,16 +27,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k_in = self.to_k(hypernetwork_layers[0](context)) - v_in = self.to_v(hypernetwork_layers[1](context)) - else: - k_in = self.to_k(context) - v_in = self.to_v(context) - del context, x + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + del context, context_k, context_v, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in @@ -59,22 +54,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): return self.to_out(r2) -# taken from https://github.com/Doggettx/stable-diffusion +# taken from https://github.com/Doggettx/stable-diffusion and modified def split_cross_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k_in = self.to_k(hypernetwork_layers[0](context)) - v_in = self.to_v(hypernetwork_layers[1](context)) - else: - k_in = self.to_k(context) - v_in = self.to_v(context) + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) k_in *= self.scale @@ -130,14 +119,11 @@ def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - if hypernetwork_layers is not None: - k_in = self.to_k(hypernetwork_layers[0](context)) - v_in = self.to_v(hypernetwork_layers[1](context)) - else: - k_in = self.to_k(context) - v_in = self.to_v(context) + + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) -- cgit v1.2.3 From b2368a3bce663f19a7209d9cb38617e635ca6e3c Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Tue, 11 Oct 2022 17:32:46 +0900 Subject: Switched to exception handling --- modules/textual_inversion/dataset.py | 10 +++++----- modules/textual_inversion/preprocess.py | 8 +++++--- modules/textual_inversion/textual_inversion.py | 18 ++++++++---------- 3 files changed, 18 insertions(+), 18 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 0dc54fb7..4d006366 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -22,7 +22,6 @@ class PersonalizedBase(Dataset): self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) - self.extns = [".jpg",".jpeg",".png",".webp",".bmp"] self.dataset = [] @@ -33,12 +32,13 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' - self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in self.extns] + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): - image = Image.open(path) - image = image.convert('RGB') - image = image.resize((self.width, self.height), PIL.Image.BICUBIC) + try: + image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) + except Exception: + continue filename = os.path.basename(path) filename_tokens = os.path.splitext(filename)[0] diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 8290abe8..1a672725 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,13 +12,12 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) - extns = [".jpg",".jpeg",".png",".webp",".bmp"] assert src != dst, 'same directory specified as source and destination' os.makedirs(dst, exist_ok=True) - files = [i for i in os.listdir(src) if os.path.splitext(i.casefold())[1] in extns] + files = os.listdir(src) shared.state.textinfo = "Preprocessing..." shared.state.job_count = len(files) @@ -47,7 +46,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] filename = os.path.join(src, imagefile) - img = Image.open(filename).convert("RGB") + try: + img = Image.open(filename).convert("RGB") + except Exception: + continue if shared.state.interrupted: break diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 33c923d1..91cde04b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -161,7 +161,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = "Initializing textual inversion training..." shared.state.job_count = steps - extns = [".jpg",".jpeg",".png",".webp",".bmp"] filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') @@ -201,10 +200,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if ititial_step > steps: return embedding, filename - tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root) if os.path.splitext(file_path.casefold())[1] in extns]) - - epoch_len = (tr_img_len * num_repeats) + tr_img_len - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) for i, (x, text) in pbar: embedding.step = i + ititial_step @@ -228,10 +223,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini loss.backward() optimizer.step() - epoch_num = embedding.step // epoch_len - epoch_step = embedding.step - (epoch_num * epoch_len) + 1 + epoch_num = embedding.step // len(ds) + epoch_step = embedding.step - (epoch_num * len(ds)) + 1 - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}") + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}") if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') @@ -243,9 +238,12 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, prompt=text, - steps=20, - height=training_height, + steps=28, + height=768, width=training_width, + negative_prompt="lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry, artist name", + cfg_scale=7.0, + sampler_index=0, do_not_save_grid=True, do_not_save_samples=True, ) -- cgit v1.2.3 From 8bacbca0a1ab9aabcb0ad0cbf070e0006991e98a Mon Sep 17 00:00:00 2001 From: alg-wiki Date: Tue, 11 Oct 2022 17:35:09 +0900 Subject: Removed my local edits to checkpoint image generation --- modules/textual_inversion/textual_inversion.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 91cde04b..e9ff80c2 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -238,12 +238,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, prompt=text, - steps=28, - height=768, + steps=20, + height=training_height, width=training_width, - negative_prompt="lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry, artist name", - cfg_scale=7.0, - sampler_index=0, do_not_save_grid=True, do_not_save_samples=True, ) -- cgit v1.2.3 From 530103b586109c11fd068eb70ef09503ec6a4caf Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 14:53:02 +0300 Subject: fixes related to merge --- modules/hypernetwork.py | 103 ------------------------- modules/hypernetwork/hypernetwork.py | 74 +++++++++++------- modules/hypernetwork/ui.py | 10 +-- modules/sd_hijack_optimizations.py | 3 +- modules/shared.py | 13 +++- modules/textual_inversion/textual_inversion.py | 12 +-- modules/ui.py | 5 +- 7 files changed, 73 insertions(+), 147 deletions(-) delete mode 100644 modules/hypernetwork.py (limited to 'modules') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py deleted file mode 100644 index 7bbc443e..00000000 --- a/modules/hypernetwork.py +++ /dev/null @@ -1,103 +0,0 @@ -import glob -import os -import sys -import traceback - -import torch - -from ldm.util import default -from modules import devices, shared -import torch -from torch import einsum -from einops import rearrange, repeat - - -class HypernetworkModule(torch.nn.Module): - def __init__(self, dim, state_dict): - super().__init__() - - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) - - self.load_state_dict(state_dict, strict=True) - self.to(devices.device) - - def forward(self, x): - return x + (self.linear2(self.linear1(x))) - - -class Hypernetwork: - filename = None - name = None - - def __init__(self, filename): - self.filename = filename - self.name = os.path.splitext(os.path.basename(filename))[0] - self.layers = {} - - state_dict = torch.load(filename, map_location='cpu') - for size, sd in state_dict.items(): - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) - - -def list_hypernetworks(path): - res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): - name = os.path.splitext(os.path.basename(filename))[0] - res[name] = filename - return res - - -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - if path is not None: - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork(path) - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print(f"Unloading hypernetwork") - - shared.loaded_hypernetwork = None - - -def apply_hypernetwork(hypernetwork, context): - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is None: - return context, context - - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) - return context_k, context_v - - -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context) - k = self.to_k(context_k) - v = self.to_v(context_v) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - if mask is not None: - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py index a3d6a47e..aa701bda 100644 --- a/modules/hypernetwork/hypernetwork.py +++ b/modules/hypernetwork/hypernetwork.py @@ -26,10 +26,11 @@ class HypernetworkModule(torch.nn.Module): if state_dict is not None: self.load_state_dict(state_dict, strict=True) else: - self.linear1.weight.data.fill_(0.0001) - self.linear1.bias.data.fill_(0.0001) - self.linear2.weight.data.fill_(0.0001) - self.linear2.bias.data.fill_(0.0001) + + self.linear1.weight.data.normal_(mean=0.0, std=0.01) + self.linear1.bias.data.zero_() + self.linear2.weight.data.normal_(mean=0.0, std=0.01) + self.linear2.bias.data.zero_() self.to(devices.device) @@ -92,41 +93,54 @@ class Hypernetwork: self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) -def load_hypernetworks(path): +def list_hypernetworks(path): res = {} + for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res - for filename in glob.iglob(path + '**/*.pt', recursive=True): + +def load_hypernetwork(filename): + path = shared.hypernetworks.get(filename, None) + if path is not None: + print(f"Loading hypernetwork {filename}") try: - hn = Hypernetwork() - hn.load(filename) - res[hn.name] = hn + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + except Exception: - print(f"Error loading hypernetwork {filename}", file=sys.stderr) + print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") - return res + shared.loaded_hypernetwork = None -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads +def apply_hypernetwork(hypernetwork, context, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - q = self.to_q(x) - context = default(context, x) + if hypernetwork_layers is None: + return context, context - hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None) + if layer is not None: + layer.hyper_k = hypernetwork_layers[0] + layer.hyper_v = hypernetwork_layers[1] - if hypernetwork_layers is not None: - hypernetwork_k, hypernetwork_v = hypernetwork_layers + context_k = hypernetwork_layers[0](context) + context_v = hypernetwork_layers[1](context) + return context_k, context_v - self.hypernetwork_k = hypernetwork_k - self.hypernetwork_v = hypernetwork_v - context_k = hypernetwork_k(context) - context_v = hypernetwork_v(context) - else: - context_k = context - context_v = context +def attention_CrossAttention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) k = self.to_k(context_k) v = self.to_v(context_v) @@ -151,7 +165,9 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): assert hypernetwork_name, 'embedding not selected' - shared.hypernetwork = shared.hypernetworks[hypernetwork_name] + path = shared.hypernetworks.get(hypernetwork_name, None) + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) shared.state.textinfo = "Initializing hypernetwork training..." shared.state.job_count = steps @@ -176,9 +192,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) - hypernetwork = shared.hypernetworks[hypernetwork_name] + hypernetwork = shared.loaded_hypernetwork weights = hypernetwork.weights() for weight in weights: weight.requires_grad = True @@ -194,7 +210,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if ititial_step > steps: return hypernetwork, filename - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, (x, text) in pbar: hypernetwork.step = i + ititial_step diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py index 525f978c..f6d1d0a3 100644 --- a/modules/hypernetwork/ui.py +++ b/modules/hypernetwork/ui.py @@ -6,24 +6,24 @@ import gradio as gr import modules.textual_inversion.textual_inversion import modules.textual_inversion.preprocess from modules import sd_hijack, shared +from modules.hypernetwork import hypernetwork def create_hypernetwork(name): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" - hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) - hypernetwork.save(fn) + hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernet.save(fn) shared.reload_hypernetworks() - shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None) return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" def train_hypernetwork(*args): - initial_hypernetwork = shared.hypernetwork + initial_hypernetwork = shared.loaded_hypernetwork try: sd_hijack.undo_optimizations() @@ -38,6 +38,6 @@ Hypernetwork saved to {html.escape(filename)} except Exception: raise finally: - shared.hypernetwork = initial_hypernetwork + shared.loaded_hypernetwork = initial_hypernetwork sd_hijack.apply_optimizations() diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 25cb67a4..27e571fc 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -8,7 +8,8 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared, hypernetwork +from modules import shared +from modules.hypernetwork import hypernetwork if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: diff --git a/modules/shared.py b/modules/shared.py index 14b40d70..8753015e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,8 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, hypernetwork +from modules import sd_samplers +from modules.hypernetwork import hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -29,6 +30,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") @@ -82,10 +84,17 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram xformers_available = False config_filename = cmd_opts.ui_settings_file -hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks')) +hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None +def reload_hypernetworks(): + global hypernetworks + + hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) + hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + + class State: skipped = False interrupted = False diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5965c5a0..d6977950 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -238,12 +238,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') + preview_text = text if preview_image_prompt == "" else preview_image_prompt + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, - prompt=text, + prompt=preview_text, steps=20, - height=training_height, - width=training_width, + height=training_height, + width=training_width, do_not_save_grid=True, do_not_save_samples=True, ) @@ -254,7 +256,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.current_image = image image.save(last_saved_image) - last_saved_image += f", prompt: {text}" + last_saved_image += f", prompt: {preview_text}" shared.state.job_no = embedding.step diff --git a/modules/ui.py b/modules/ui.py index 10b1ee3a..df653059 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1023,7 +1023,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="") with gr.Column(): - create_embedding = gr.Button(value="Create", variant='primary') + create_embedding = gr.Button(value="Create embedding", variant='primary') with gr.Group(): gr.HTML(value="

Create a new hypernetwork

") @@ -1035,7 +1035,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="") with gr.Column(): - create_hypernetwork = gr.Button(value="Create", variant='primary') + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') with gr.Group(): gr.HTML(value="

Preprocess images

") @@ -1147,6 +1147,7 @@ def create_ui(wrap_gradio_gpu_call): create_image_every, save_embedding_every, template_file, + preview_image_prompt, ], outputs=[ ti_output, -- cgit v1.2.3 From 7b1db45e1fda8603d4617affd976066be5e5b821 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Tue, 11 Oct 2022 20:17:27 +0800 Subject: images history improvement --- modules/images_history.py | 229 ++++++++++++++++++++++++---------------------- 1 file changed, 121 insertions(+), 108 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 0e0a48f3..01d11a01 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -1,117 +1,130 @@ import os def get_recent_images(dir_name, page_index, step, image_index): - print(image_index) - page_index = int(page_index) - f_list = os.listdir(dir_name) - file_list = [] - for file in f_list: - if file[-4:] == ".txt": - continue - file_list.append(file) - file_list = sorted(file_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) - num = 48 - max_page_index = len(file_list) // num + 1 - page_index = max_page_index if page_index == -1 else page_index + step - page_index = 1 if page_index < 1 else page_index - page_index = max_page_index if page_index > max_page_index else page_index - idx_frm = (page_index - 1) * num - file_list = file_list[idx_frm:idx_frm + num] - print(f"Loading history page {page_index}") - image_index = int(image_index) - if image_index < 0 or image_index > len(file_list) - 1: - current_file = None - hide_image = None - else: - current_file = file_list[int(image_index)] - hide_image = os.path.join(dir_name, current_file) - return [os.path.join(dir_name, file) for file in file_list], page_index, file_list, current_file, hide_image -def first_page_click(dir_name, page_index, image_index): - return get_recent_images(dir_name, 1, 0, image_index) -def end_page_click(dir_name, page_index, image_index): - return get_recent_images(dir_name, -1, 0, image_index) -def prev_page_click(dir_name, page_index, image_index): - return get_recent_images(dir_name, page_index, -1, image_index) -def next_page_click(dir_name, page_index, image_index): - return get_recent_images(dir_name, page_index, 1, image_index) -def page_index_change(dir_name, page_index, image_index): - return get_recent_images(dir_name, page_index, 0, image_index) + #print(image_index) + page_index = int(page_index) + f_list = os.listdir(dir_name) + file_list = [] + for file in f_list: + if file[-4:] == ".txt": + continue + file_list.append(file) + file_list = sorted(file_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) + num = 48 + max_page_index = len(file_list) // num + 1 + page_index = max_page_index if page_index == -1 else page_index + step + page_index = 1 if page_index < 1 else page_index + page_index = max_page_index if page_index > max_page_index else page_index + idx_frm = (page_index - 1) * num + file_list = file_list[idx_frm:idx_frm + num] + #print(f"Loading history page {page_index}") + image_index = int(image_index) + if image_index < 0 or image_index > len(file_list) - 1: + current_file = None + hide_image = None + else: + current_file = file_list[int(image_index)] + hide_image = os.path.join(dir_name, current_file) + return [os.path.join(dir_name, file) for file in file_list], page_index, file_list, current_file, hide_image +def first_page_click(dir_name, page_index, image_index, tabname): + return get_recent_images(dir_name, 1, 0, image_index) +def end_page_click(dir_name, page_index, image_index, tabname): + return get_recent_images(dir_name, -1, 0, image_index) +def prev_page_click(dir_name, page_index, image_index, tabname): + return get_recent_images(dir_name, page_index, -1, image_index) +def next_page_click(dir_name, page_index, image_index, tabname): + return get_recent_images(dir_name, page_index, 1, image_index) +def page_index_change(dir_name, page_index, image_index, tabname): + return get_recent_images(dir_name, page_index, 0, image_index) def show_image_info(num, image_path, filenames): - file = filenames[int(num)] - return file, num, os.path.join(image_path, file) -def delete_image(is_img2img, dir_name, name, page_index, filenames, image_index): - print("filename", name) - path = os.path.join(dir_name, name) - if os.path.exists(path): - print(f"Delete file {path}") - os.remove(path) - images, page_index, file_list, current_file, hide_image = get_recent_images(dir_name, page_index, 0, image_index) - return images, page_index, file_list, current_file, hide_image + #print("set img",num) + file = filenames[int(num)] + return file, num, os.path.join(image_path, file) +def delete_image(tabname, dir_name, name, page_index, filenames, image_index): + #print("filename", name) + path = os.path.join(dir_name, name) + if os.path.exists(path): + print(f"Delete file {path}") + os.remove(path) + new_file_list = [] + for f in filenames: + if f == name: + continue + new_file_list.append(f) + else: + print(f"Not exists file {path}") + new_file_list = filenames + return page_index, new_file_list +def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): + if tabname == "txt2img": + dir_name = opts.outdir_txt2img_samples + elif tabname == "img2img": + dir_name = opts.outdir_img2img_samples + elif tabname == "extras": + dir_name = opts.outdir_extras_samples + with gr.Row(): + renew_page = gr.Button('Renew', elem_id=tabname + "_images_history_renew_page") + first_page = gr.Button('First', elem_id=tabname + "_images_history_first_page") + prev_page = gr.Button('Prev') + page_index = gr.Number(value=1, label="Page Index") + next_page = gr.Button('Next', elem_id=tabname + "_images_history_next_page") + end_page = gr.Button('End') + with gr.Row(elem_id=tabname + "_images_history"): + with gr.Row(): + with gr.Column(): + history_gallery = gr.Gallery(show_label=False).style(grid=6) + with gr.Column(): + with gr.Row(): + delete = gr.Button('Delete') + pnginfo_send_to_txt2img = gr.Button('Send to txt2img') + pnginfo_send_to_img2img = gr.Button('Send to img2img') + with gr.Row(): + with gr.Column(): + img_file_info = gr.Textbox(label="Generate Info") + img_file_name = gr.Textbox(label="File Name") + with gr.Row(): + # hiden items + img_path = gr.Textbox(dir_name, visible=False) + tabname_box = gr.Textbox(tabname, visible=False) + image_index = gr.Textbox(value=-1, visible=False) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False) + filenames = gr.State() + hide_image = gr.Image(visible=False, type="pil") + info1 = gr.Textbox(visible=False) + info2 = gr.Textbox(visible=False) -def show_images_history(gr, opts, is_img2img, run_pnginfo, switch_dict): - def id_name(is_img2img, name): - return ("img2img" if is_img2img else "txt2img") + "_" + name - if is_img2img: - dir_name = opts.outdir_img2img_samples - else: - dir_name = opts.outdir_txt2img_samples - with gr.Row(): - first_page = gr.Button('First', elem_id=id_name(is_img2img,"images_history_first_page")) - prev_page = gr.Button('Prev') - page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next') - end_page = gr.Button('End') - with gr.Row(elem_id=id_name(is_img2img,"images_history")): - with gr.Row(): - with gr.Column(): - history_gallery = gr.Gallery(show_label=False).style(grid=6) - with gr.Column(): - with gr.Row(): - delete = gr.Button('Delete') - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') - with gr.Row(): - with gr.Column(): - img_file_info = gr.Textbox(dir_name, label="Generate Info") - img_file_name = gr.Textbox(label="File Name") - with gr.Row(): - # hiden items - img_path = gr.Textbox(dir_name, visible=False) - is_img2img_flag = gr.Checkbox(is_img2img, visible=False) - image_index = gr.Textbox(value=-1, visible=False) - set_index = gr.Button('set_index', elem_id=id_name(is_img2img,"images_history_set_index")) - filenames = gr.State() - hide_image = gr.Image(visible=False, type="pil") - info1 = gr.Textbox(visible=False) - info2 = gr.Textbox(visible=False) + + # turn pages + gallery_inputs = [img_path, page_index, image_index, tabname_box] + gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hide_image] - - # turn pages - gallery_inputs = [img_path, page_index, image_index] - gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hide_image] - first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) - #page_index.change(page_index_change, inputs=[is_img2img_flag, img_path, page_index], outputs=[history_gallery, page_index]) + first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + #page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index]) - #other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[is_img2img_flag, img_path, filenames], outputs=[img_file_name, image_index, hide_image]) - delete.click(delete_image, inputs=[is_img2img_flag, img_path, img_file_name, page_index, filenames, image_index], outputs=gallery_outputs) - hide_image.change(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) - switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') - switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') - - + #other funcitons + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hide_image]) + delete.click(delete_image,_js="images_history_delete", inputs=[tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[page_index, filenames]) + hide_image.change(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) + switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') + switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') + + def create_history_tabs(gr, opts, run_pnginfo, switch_dict): - with gr.Blocks(analytics_enabled=False) as images_history: - with gr.Tabs() as tabs: - with gr.Tab("txt2img history", id="images_history_txt2img"): - with gr.Blocks(analytics_enabled=False) as images_history_txt2img: - show_images_history(gr, opts, False, run_pnginfo, switch_dict) - with gr.Tab("img2img history", id="images_history_img2img"): - with gr.Blocks(analytics_enabled=False) as images_history_img2img: - show_images_history(gr, opts, True, run_pnginfo, switch_dict) - return images_history + with gr.Blocks(analytics_enabled=False) as images_history: + with gr.Tabs() as tabs: + with gr.Tab("txt2img history"): + with gr.Blocks(analytics_enabled=False) as images_history_txt2img: + show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict) + with gr.Tab("img2img history"): + with gr.Blocks(analytics_enabled=False) as images_history_img2img: + show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict) + with gr.Tab("extras history"): + with gr.Blocks(analytics_enabled=False) as images_history_img2img: + show_images_history(gr, opts, "extras", run_pnginfo, switch_dict) + return images_history -- cgit v1.2.3 From 92d7a138857b308c97a8d009848f642aeb93d6c8 Mon Sep 17 00:00:00 2001 From: Martin Cairns Date: Tue, 11 Oct 2022 00:02:44 +0100 Subject: Handle different parameters for DPM fast & adaptive --- modules/sd_samplers.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index d168b938..eee52e7d 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -57,7 +57,7 @@ def set_samplers(): global samplers, samplers_for_img2img hidden = set(opts.hide_samplers) - hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive']) + hidden_img2img = set(opts.hide_samplers + ['PLMS']) samplers = [x for x in all_samplers if x.name not in hidden] samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] @@ -365,16 +365,27 @@ class KDiffusionSampler: else: sigmas = self.model_wrap.get_sigmas(steps) - noise = noise * sigmas[steps - t_enc - 1] - xi = x + noise - - extra_params_kwargs = self.initialize(p) - sigma_sched = sigmas[steps - t_enc - 1:] + print('check values same', sigmas[steps - t_enc - 1] , sigma_sched[0], sigmas[steps - t_enc - 1] - sigma_sched[0]) + xi = x + noise * sigma_sched[0] + + extra_params_kwargs = self.initialize(p) + if 'sigma_min' in inspect.signature(self.func).parameters: + ## last sigma is zero which is allowed by DPM Fast & Adaptive so taking value before last + extra_params_kwargs['sigma_min'] = sigma_sched[-2] + if 'sigma_max' in inspect.signature(self.func).parameters: + extra_params_kwargs['sigma_max'] = sigma_sched[0] + if 'n' in inspect.signature(self.func).parameters: + extra_params_kwargs['n'] = len(sigma_sched) - 1 + if 'sigma_sched' in inspect.signature(self.func).parameters: + extra_params_kwargs['sigma_sched'] = sigma_sched + if 'sigmas' in inspect.signature(self.func).parameters: + extra_params_kwargs['sigmas'] = sigma_sched self.model_wrap_cfg.init_latent = x - return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) + return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): steps = steps or p.steps -- cgit v1.2.3 From 1eae3076078f00ecc5d0fac3c77fffb85cd2eb77 Mon Sep 17 00:00:00 2001 From: Martin Cairns Date: Tue, 11 Oct 2022 00:04:06 +0100 Subject: Remove debug code for checking that first sigma value is same after code cleanup --- modules/sd_samplers.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index eee52e7d..32272916 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -366,7 +366,6 @@ class KDiffusionSampler: sigmas = self.model_wrap.get_sigmas(steps) sigma_sched = sigmas[steps - t_enc - 1:] - print('check values same', sigmas[steps - t_enc - 1] , sigma_sched[0], sigmas[steps - t_enc - 1] - sigma_sched[0]) xi = x + noise * sigma_sched[0] extra_params_kwargs = self.initialize(p) -- cgit v1.2.3 From eacc03b16730bcc5be95cda2d7c966ff1b4a8263 Mon Sep 17 00:00:00 2001 From: Martin Cairns Date: Tue, 11 Oct 2022 00:36:00 +0100 Subject: Fix typo in comments --- modules/sd_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 32272916..20309e06 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -370,7 +370,7 @@ class KDiffusionSampler: extra_params_kwargs = self.initialize(p) if 'sigma_min' in inspect.signature(self.func).parameters: - ## last sigma is zero which is allowed by DPM Fast & Adaptive so taking value before last + ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last extra_params_kwargs['sigma_min'] = sigma_sched[-2] if 'sigma_max' in inspect.signature(self.func).parameters: extra_params_kwargs['sigma_max'] = sigma_sched[0] -- cgit v1.2.3 From 87d63bbab5c973ac5cec777ef7304d28f1ab3f24 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Tue, 11 Oct 2022 20:37:03 +0800 Subject: images history improvement --- modules/images_history.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 01d11a01..23f55b30 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -64,12 +64,12 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): elif tabname == "extras": dir_name = opts.outdir_extras_samples with gr.Row(): - renew_page = gr.Button('Renew', elem_id=tabname + "_images_history_renew_page") - first_page = gr.Button('First', elem_id=tabname + "_images_history_first_page") - prev_page = gr.Button('Prev') - page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next', elem_id=tabname + "_images_history_next_page") - end_page = gr.Button('End') + renew_page = gr.Button('Renew', elem_id=tabname + "_images_history_renew_page") + first_page = gr.Button('First', elem_id=tabname + "_images_history_first_page") + prev_page = gr.Button('Prev') + page_index = gr.Number(value=1, label="Page Index") + next_page = gr.Button('Next', elem_id=tabname + "_images_history_next_page") + end_page = gr.Button('End') with gr.Row(elem_id=tabname + "_images_history"): with gr.Row(): with gr.Column(): @@ -84,15 +84,15 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): img_file_info = gr.Textbox(label="Generate Info") img_file_name = gr.Textbox(label="File Name") with gr.Row(): - # hiden items - img_path = gr.Textbox(dir_name, visible=False) - tabname_box = gr.Textbox(tabname, visible=False) - image_index = gr.Textbox(value=-1, visible=False) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False) - filenames = gr.State() - hide_image = gr.Image(visible=False, type="pil") - info1 = gr.Textbox(visible=False) - info2 = gr.Textbox(visible=False) + # hiden items + img_path = gr.Textbox(dir_name, visible=False) + tabname_box = gr.Textbox(tabname, visible=False) + image_index = gr.Textbox(value=-1, visible=False) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False) + filenames = gr.State() + hide_image = gr.Image(visible=False, type="pil") + info1 = gr.Textbox(visible=False) + info2 = gr.Textbox(visible=False) # turn pages -- cgit v1.2.3 From 87b77cad5f3017c952a7dfec0e7904a9df5b72fd Mon Sep 17 00:00:00 2001 From: Ben <110583491+TheLastBen@users.noreply.github.com> Date: Mon, 10 Oct 2022 19:37:16 +0100 Subject: Layout fix --- modules/ui.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index df653059..de4cd7f2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -550,15 +550,15 @@ def create_ui(wrap_gradio_gpu_call): button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id) - with gr.Row(): - do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) + with gr.Row(): + do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) - with gr.Group(): - html_info = gr.HTML() - generation_info = gr.Textbox(visible=False) + with gr.Group(): + html_info = gr.HTML() + generation_info = gr.Textbox(visible=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -738,15 +738,15 @@ def create_ui(wrap_gradio_gpu_call): button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id) - with gr.Row(): - do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) + with gr.Row(): + do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) - with gr.Group(): - html_info = gr.HTML() - generation_info = gr.Textbox(visible=False) + with gr.Group(): + html_info = gr.HTML() + generation_info = gr.Textbox(visible=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) -- cgit v1.2.3 From 861297cefe2bb663f4e09dd4778a4cb93ebe8ff1 Mon Sep 17 00:00:00 2001 From: Ben <110583491+TheLastBen@users.noreply.github.com> Date: Tue, 11 Oct 2022 08:08:45 +0100 Subject: add a space holder --- modules/ui.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index de4cd7f2..fc0f3d3c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,7 +429,10 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=8): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2) + with gr.Row(): + negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2) + with gr.Column(scale=1, elem_id="roll_col"): + sh = gr.Button(elem_id="sh", visible=True) with gr.Column(scale=1, elem_id="style_neg_col"): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) -- cgit v1.2.3 From 59925644480b6fd84f6bb84b4df7d4fbc6a0cce8 Mon Sep 17 00:00:00 2001 From: JamnedZ Date: Tue, 11 Oct 2022 16:40:27 +0700 Subject: Cleaned ngrok integration --- modules/ngrok.py | 15 +++++++++++++++ modules/shared.py | 1 + modules/ui.py | 5 +++++ 3 files changed, 21 insertions(+) create mode 100644 modules/ngrok.py (limited to 'modules') diff --git a/modules/ngrok.py b/modules/ngrok.py new file mode 100644 index 00000000..17e6976f --- /dev/null +++ b/modules/ngrok.py @@ -0,0 +1,15 @@ +from pyngrok import ngrok, conf, exception + + +def connect(token, port): + if token == None: + token = 'None' + conf.get_default().auth_token = token + try: + public_url = ngrok.connect(port).public_url + except exception.PyngrokNgrokError: + print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' + f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') + else: + print(f'ngrok connected to localhost:{port}! URL: {public_url}\n' + 'You can use this link after the launch is complete.') \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index 8753015e..375e3afb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -38,6 +38,7 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") +parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) diff --git a/modules/ui.py b/modules/ui.py index fc0f3d3c..f57f32db 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -51,6 +51,11 @@ if not cmd_opts.share and not cmd_opts.listen: gradio.utils.version_check = lambda: None gradio.utils.get_local_ip_address = lambda: '127.0.0.1' +if cmd_opts.ngrok != None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860) + def gr_show(visible=True): return {"visible": visible, "__type__": "update"} -- cgit v1.2.3 From a004d1a855311b0d7ff2976a4e31b0247ad9d1f6 Mon Sep 17 00:00:00 2001 From: JamnedZ Date: Tue, 11 Oct 2022 16:48:27 +0700 Subject: Added new line at the end of ngrok.py --- modules/ngrok.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ngrok.py b/modules/ngrok.py index 17e6976f..7d03a6df 100644 --- a/modules/ngrok.py +++ b/modules/ngrok.py @@ -12,4 +12,4 @@ def connect(token, port): f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') else: print(f'ngrok connected to localhost:{port}! URL: {public_url}\n' - 'You can use this link after the launch is complete.') \ No newline at end of file + 'You can use this link after the launch is complete.') -- cgit v1.2.3 From 873efeed49bb5197a42da18272115b326c5d68f3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 15:51:22 +0300 Subject: rename hypernetwork dir to hypernetworks to prevent clash with an old filename that people who use zip instead of git clone will have --- modules/hypernetwork/hypernetwork.py | 283 ---------------------------------- modules/hypernetwork/ui.py | 43 ------ modules/hypernetworks/hypernetwork.py | 283 ++++++++++++++++++++++++++++++++++ modules/hypernetworks/ui.py | 43 ++++++ modules/sd_hijack.py | 2 +- modules/sd_hijack_optimizations.py | 2 +- modules/shared.py | 2 +- modules/ui.py | 2 +- 8 files changed, 330 insertions(+), 330 deletions(-) delete mode 100644 modules/hypernetwork/hypernetwork.py delete mode 100644 modules/hypernetwork/ui.py create mode 100644 modules/hypernetworks/hypernetwork.py create mode 100644 modules/hypernetworks/ui.py (limited to 'modules') diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py deleted file mode 100644 index aa701bda..00000000 --- a/modules/hypernetwork/hypernetwork.py +++ /dev/null @@ -1,283 +0,0 @@ -import datetime -import glob -import html -import os -import sys -import traceback -import tqdm - -import torch - -from ldm.util import default -from modules import devices, shared, processing, sd_models -import torch -from torch import einsum -from einops import rearrange, repeat -import modules.textual_inversion.dataset - - -class HypernetworkModule(torch.nn.Module): - def __init__(self, dim, state_dict=None): - super().__init__() - - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) - - if state_dict is not None: - self.load_state_dict(state_dict, strict=True) - else: - - self.linear1.weight.data.normal_(mean=0.0, std=0.01) - self.linear1.bias.data.zero_() - self.linear2.weight.data.normal_(mean=0.0, std=0.01) - self.linear2.bias.data.zero_() - - self.to(devices.device) - - def forward(self, x): - return x + (self.linear2(self.linear1(x))) - - -class Hypernetwork: - filename = None - name = None - - def __init__(self, name=None): - self.filename = None - self.name = name - self.layers = {} - self.step = 0 - self.sd_checkpoint = None - self.sd_checkpoint_name = None - - for size in [320, 640, 768, 1280]: - self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) - - def weights(self): - res = [] - - for k, layers in self.layers.items(): - for layer in layers: - layer.train() - res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] - - return res - - def save(self, filename): - state_dict = {} - - for k, v in self.layers.items(): - state_dict[k] = (v[0].state_dict(), v[1].state_dict()) - - state_dict['step'] = self.step - state_dict['name'] = self.name - state_dict['sd_checkpoint'] = self.sd_checkpoint - state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name - - torch.save(state_dict, filename) - - def load(self, filename): - self.filename = filename - if self.name is None: - self.name = os.path.splitext(os.path.basename(filename))[0] - - state_dict = torch.load(filename, map_location='cpu') - - for size, sd in state_dict.items(): - if type(size) == int: - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) - - self.name = state_dict.get('name', self.name) - self.step = state_dict.get('step', 0) - self.sd_checkpoint = state_dict.get('sd_checkpoint', None) - self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) - - -def list_hypernetworks(path): - res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): - name = os.path.splitext(os.path.basename(filename))[0] - res[name] = filename - return res - - -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - if path is not None: - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) - - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print(f"Unloading hypernetwork") - - shared.loaded_hypernetwork = None - - -def apply_hypernetwork(hypernetwork, context, layer=None): - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is None: - return context, context - - if layer is not None: - layer.hyper_k = hypernetwork_layers[0] - layer.hyper_v = hypernetwork_layers[1] - - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) - return context_k, context_v - - -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) - k = self.to_k(context_k) - v = self.to_v(context_v) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - if mask is not None: - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) - - -def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): - assert hypernetwork_name, 'embedding not selected' - - path = shared.hypernetworks.get(hypernetwork_name, None) - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) - - shared.state.textinfo = "Initializing hypernetwork training..." - shared.state.job_count = steps - - filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') - - log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) - - if save_hypernetwork_every > 0: - hypernetwork_dir = os.path.join(log_directory, "hypernetworks") - os.makedirs(hypernetwork_dir, exist_ok=True) - else: - hypernetwork_dir = None - - if create_image_every > 0: - images_dir = os.path.join(log_directory, "images") - os.makedirs(images_dir, exist_ok=True) - else: - images_dir = None - - cond_model = shared.sd_model.cond_stage_model - - shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." - with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) - - hypernetwork = shared.loaded_hypernetwork - weights = hypernetwork.weights() - for weight in weights: - weight.requires_grad = True - - optimizer = torch.optim.AdamW(weights, lr=learn_rate) - - losses = torch.zeros((32,)) - - last_saved_file = "" - last_saved_image = "" - - ititial_step = hypernetwork.step or 0 - if ititial_step > steps: - return hypernetwork, filename - - pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, (x, text) in pbar: - hypernetwork.step = i + ititial_step - - if hypernetwork.step > steps: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = cond_model([text]) - - x = x.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] - del x - - losses[hypernetwork.step % losses.shape[0]] = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - pbar.set_description(f"loss: {losses.mean():.7f}") - - if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') - hypernetwork.save(last_saved_file) - - if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: - last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') - - preview_text = text if preview_image_prompt == "" else preview_image_prompt - - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - prompt=preview_text, - steps=20, - do_not_save_grid=True, - do_not_save_samples=True, - ) - - processed = processing.process_images(p) - image = processed.images[0] - - shared.state.current_image = image - image.save(last_saved_image) - - last_saved_image += f", prompt: {preview_text}" - - shared.state.job_no = hypernetwork.step - - shared.state.textinfo = f""" -

-Loss: {losses.mean():.7f}
-Step: {hypernetwork.step}
-Last prompt: {html.escape(text)}
-Last saved embedding: {html.escape(last_saved_file)}
-Last saved image: {html.escape(last_saved_image)}
-

-""" - - checkpoint = sd_models.select_checkpoint() - - hypernetwork.sd_checkpoint = checkpoint.hash - hypernetwork.sd_checkpoint_name = checkpoint.model_name - hypernetwork.save(filename) - - return hypernetwork, filename - - diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py deleted file mode 100644 index f6d1d0a3..00000000 --- a/modules/hypernetwork/ui.py +++ /dev/null @@ -1,43 +0,0 @@ -import html -import os - -import gradio as gr - -import modules.textual_inversion.textual_inversion -import modules.textual_inversion.preprocess -from modules import sd_hijack, shared -from modules.hypernetwork import hypernetwork - - -def create_hypernetwork(name): - fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") - assert not os.path.exists(fn), f"file {fn} already exists" - - hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) - hypernet.save(fn) - - shared.reload_hypernetworks() - - return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" - - -def train_hypernetwork(*args): - - initial_hypernetwork = shared.loaded_hypernetwork - - try: - sd_hijack.undo_optimizations() - - hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) - - res = f""" -Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. -Hypernetwork saved to {html.escape(filename)} -""" - return res, "" - except Exception: - raise - finally: - shared.loaded_hypernetwork = initial_hypernetwork - sd_hijack.apply_optimizations() - diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py new file mode 100644 index 00000000..aa701bda --- /dev/null +++ b/modules/hypernetworks/hypernetwork.py @@ -0,0 +1,283 @@ +import datetime +import glob +import html +import os +import sys +import traceback +import tqdm + +import torch + +from ldm.util import default +from modules import devices, shared, processing, sd_models +import torch +from torch import einsum +from einops import rearrange, repeat +import modules.textual_inversion.dataset + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict=None): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + if state_dict is not None: + self.load_state_dict(state_dict, strict=True) + else: + + self.linear1.weight.data.normal_(mean=0.0, std=0.01) + self.linear1.bias.data.zero_() + self.linear2.weight.data.normal_(mean=0.0, std=0.01) + self.linear2.bias.data.zero_() + + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, name=None): + self.filename = None + self.name = name + self.layers = {} + self.step = 0 + self.sd_checkpoint = None + self.sd_checkpoint_name = None + + for size in [320, 640, 768, 1280]: + self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) + + def weights(self): + res = [] + + for k, layers in self.layers.items(): + for layer in layers: + layer.train() + res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] + + return res + + def save(self, filename): + state_dict = {} + + for k, v in self.layers.items(): + state_dict[k] = (v[0].state_dict(), v[1].state_dict()) + + state_dict['step'] = self.step + state_dict['name'] = self.name + state_dict['sd_checkpoint'] = self.sd_checkpoint + state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name + + torch.save(state_dict, filename) + + def load(self, filename): + self.filename = filename + if self.name is None: + self.name = os.path.splitext(os.path.basename(filename))[0] + + state_dict = torch.load(filename, map_location='cpu') + + for size, sd in state_dict.items(): + if type(size) == int: + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + self.name = state_dict.get('name', self.name) + self.step = state_dict.get('step', 0) + self.sd_checkpoint = state_dict.get('sd_checkpoint', None) + self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) + + +def list_hypernetworks(path): + res = {} + for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res + + +def load_hypernetwork(filename): + path = shared.hypernetworks.get(filename, None) + if path is not None: + print(f"Loading hypernetwork {filename}") + try: + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + + except Exception: + print(f"Error loading hypernetwork {path}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") + + shared.loaded_hypernetwork = None + + +def apply_hypernetwork(hypernetwork, context, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is None: + return context, context + + if layer is not None: + layer.hyper_k = hypernetwork_layers[0] + layer.hyper_v = hypernetwork_layers[1] + + context_k = hypernetwork_layers[0](context) + context_v = hypernetwork_layers[1](context) + return context_k, context_v + + +def attention_CrossAttention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) + k = self.to_k(context_k) + v = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): + assert hypernetwork_name, 'embedding not selected' + + path = shared.hypernetworks.get(hypernetwork_name, None) + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + + shared.state.textinfo = "Initializing hypernetwork training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) + + if save_hypernetwork_every > 0: + hypernetwork_dir = os.path.join(log_directory, "hypernetworks") + os.makedirs(hypernetwork_dir, exist_ok=True) + else: + hypernetwork_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + cond_model = shared.sd_model.cond_stage_model + + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + with torch.autocast("cuda"): + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + + hypernetwork = shared.loaded_hypernetwork + weights = hypernetwork.weights() + for weight in weights: + weight.requires_grad = True + + optimizer = torch.optim.AdamW(weights, lr=learn_rate) + + losses = torch.zeros((32,)) + + last_saved_file = "" + last_saved_image = "" + + ititial_step = hypernetwork.step or 0 + if ititial_step > steps: + return hypernetwork, filename + + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) + for i, (x, text) in pbar: + hypernetwork.step = i + ititial_step + + if hypernetwork.step > steps: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([text]) + + x = x.to(devices.device) + loss = shared.sd_model(x.unsqueeze(0), c)[0] + del x + + losses[hypernetwork.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description(f"loss: {losses.mean():.7f}") + + if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') + hypernetwork.save(last_saved_file) + + if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: + last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') + + preview_text = text if preview_image_prompt == "" else preview_image_prompt + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + prompt=preview_text, + steps=20, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + processed = processing.process_images(p) + image = processed.images[0] + + shared.state.current_image = image + image.save(last_saved_image) + + last_saved_image += f", prompt: {preview_text}" + + shared.state.job_no = hypernetwork.step + + shared.state.textinfo = f""" +

+Loss: {losses.mean():.7f}
+Step: {hypernetwork.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+

+""" + + checkpoint = sd_models.select_checkpoint() + + hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint_name = checkpoint.model_name + hypernetwork.save(filename) + + return hypernetwork, filename + + diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py new file mode 100644 index 00000000..811bc31e --- /dev/null +++ b/modules/hypernetworks/ui.py @@ -0,0 +1,43 @@ +import html +import os + +import gradio as gr + +import modules.textual_inversion.textual_inversion +import modules.textual_inversion.preprocess +from modules import sd_hijack, shared +from modules.hypernetworks import hypernetwork + + +def create_hypernetwork(name): + fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernet.save(fn) + + shared.reload_hypernetworks() + + return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" + + +def train_hypernetwork(*args): + + initial_hypernetwork = shared.loaded_hypernetwork + + try: + sd_hijack.undo_optimizations() + + hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. +Hypernetwork saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + shared.loaded_hypernetwork = initial_hypernetwork + sd_hijack.apply_optimizations() + diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f873049a..f07ec041 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -37,7 +37,7 @@ def apply_optimizations(): def undo_optimizations(): - from modules.hypernetwork import hypernetwork + from modules.hypernetworks import hypernetwork ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 27e571fc..3349b9c3 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from ldm.util import default from einops import rearrange from modules import shared -from modules.hypernetwork import hypernetwork +from modules.hypernetworks import hypernetwork if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: diff --git a/modules/shared.py b/modules/shared.py index 375e3afb..1dc2ccf2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.sd_models import modules.styles import modules.devices as devices from modules import sd_samplers -from modules.hypernetwork import hypernetwork +from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') diff --git a/modules/ui.py b/modules/ui.py index f57f32db..42e5d866 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -39,7 +39,7 @@ import modules.generation_parameters_copypaste from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui -import modules.hypernetwork.ui +import modules.hypernetworks.ui # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() -- cgit v1.2.3 From b0583be0884cd17dafb408fd79b52b2a0a972563 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 15:54:34 +0300 Subject: more renames --- modules/hypernetworks/ui.py | 4 ++-- modules/ui.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 811bc31e..e7540f41 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -13,7 +13,7 @@ def create_hypernetwork(name): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" - hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name) hypernet.save(fn) shared.reload_hypernetworks() @@ -28,7 +28,7 @@ def train_hypernetwork(*args): try: sd_hijack.undo_optimizations() - hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) + hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args) res = f""" Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. diff --git a/modules/ui.py b/modules/ui.py index 42e5d866..ee333c3b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1111,7 +1111,7 @@ def create_ui(wrap_gradio_gpu_call): ) create_hypernetwork.click( - fn=modules.hypernetwork.ui.create_hypernetwork, + fn=modules.hypernetworks.ui.create_hypernetwork, inputs=[ new_hypernetwork_name, ], @@ -1164,7 +1164,7 @@ def create_ui(wrap_gradio_gpu_call): ) train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetwork.ui.train_hypernetwork, extra_outputs=[gr.update()]), + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ train_hypernetwork_name, -- cgit v1.2.3 From d01a2d01560b31937df1f3433d210c18f97d32fa Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Tue, 11 Oct 2022 08:03:31 -0500 Subject: move list refresh to webui.py and add stdout indicating it's doing so --- modules/ui.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 06ff118f..ae9317a3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -39,7 +39,6 @@ import modules.generation_parameters_copypaste from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui -from modules.sd_models import list_models # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -1291,8 +1290,6 @@ Requested path was: {f} shared.state.interrupt() settings_interface.gradio_ref.do_restart = True - # refresh models so that new models/.ckpt's show up on reload - list_models() restart_gradio.click( fn=request_restart, -- cgit v1.2.3 From 66b7d7584f0b44ce1316425808c27ca7df38293c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 17:03:00 +0300 Subject: become even stricter with pickles no pickle shall pass thank you again, RyotaK --- modules/safe.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'modules') diff --git a/modules/safe.py b/modules/safe.py index 05917463..20be16a5 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -10,6 +10,7 @@ import torch import numpy import _codecs import zipfile +import re # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage @@ -54,11 +55,27 @@ class RestrictedUnpickler(pickle.Unpickler): raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") +allowed_zip_names = ["archive/data.pkl", "archive/version"] +allowed_zip_names_re = re.compile(r"^archive/data/\d+$") + + +def check_zip_filenames(filename, names): + for name in names: + if name in allowed_zip_names: + continue + if allowed_zip_names_re.match(name): + continue + + raise Exception(f"bad file inside {filename}: {name}") + + def check_pt(filename): try: # new pytorch format is a zip file with zipfile.ZipFile(filename) as z: + check_zip_filenames(filename, z.namelist()) + with z.open('archive/data.pkl') as file: unpickler = RestrictedUnpickler(file) unpickler.load() -- cgit v1.2.3 From c0484f1b986ce7acb0e3596f6089a191279f5442 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 10 Oct 2022 22:48:54 -0400 Subject: Add cross-attention optimization from InvokeAI * Add cross-attention optimization from InvokeAI (~30% speed improvement on MPS) * Add command line option for it * Make it default when CUDA is unavailable --- modules/sd_hijack.py | 5 ++- modules/sd_hijack_optimizations.py | 79 ++++++++++++++++++++++++++++++++++++++ modules/shared.py | 5 ++- 3 files changed, 86 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f07ec041..5a1b167f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -30,8 +30,11 @@ def apply_optimizations(): elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): + print("Applying cross attention optimization (InvokeAI).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - print("Applying cross attention optimization.") + print("Applying cross attention optimization (Doggettx).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 3349b9c3..870226c5 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,6 +1,7 @@ import math import sys import traceback +import psutil import torch from torch import einsum @@ -116,6 +117,84 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) +# -- From https://github.com/invoke-ai/InvokeAI/blob/main/ldm/modules/attention.py (with hypernetworks support added) -- + +mem_total_gb = psutil.virtual_memory().total // (1 << 30) + +def einsum_op_compvis(q, k, v): + s = einsum('b i d, b j d -> b i j', q, k) + s = s.softmax(dim=-1, dtype=s.dtype) + return einsum('b i j, b j d -> b i d', s, v) + +def einsum_op_slice_0(q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], slice_size): + end = i + slice_size + r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) + return r + +def einsum_op_slice_1(q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) + return r + +def einsum_op_mps_v1(q, k, v): + if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 + return einsum_op_compvis(q, k, v) + else: + slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) + return einsum_op_slice_1(q, k, v, slice_size) + +def einsum_op_mps_v2(q, k, v): + if mem_total_gb > 8 and q.shape[1] <= 4096: + return einsum_op_compvis(q, k, v) + else: + return einsum_op_slice_0(q, k, v, 1) + +def einsum_op_tensor_mem(q, k, v, max_tensor_mb): + size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) + if size_mb <= max_tensor_mb: + return einsum_op_compvis(q, k, v) + div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() + if div <= q.shape[0]: + return einsum_op_slice_0(q, k, v, q.shape[0] // div) + return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + +def einsum_op(q, k, v): + if q.device.type == 'mps': + if mem_total_gb >= 32: + return einsum_op_mps_v1(q, k, v) + return einsum_op_mps_v2(q, k, v) + + # Smaller slices are faster due to L2/L3/SLC caches. + # Tested on i7 with 8MB L3 cache. + return einsum_op_tensor_mem(q, k, v, 32) + +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + hypernetwork = shared.loaded_hypernetwork + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k = self.to_k(hypernetwork_layers[0](context)) * self.scale + v = self.to_v(hypernetwork_layers[1](context)) + else: + k = self.to_k(context) * self.scale + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + r = einsum_op(q, k, v) + return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) + +# -- End of code from https://github.com/invoke-ai/InvokeAI/blob/main/ldm/modules/attention.py -- + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) diff --git a/modules/shared.py b/modules/shared.py index 1dc2ccf2..20b45f23 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -50,9 +50,10 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator") -parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") -parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") +parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") +parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") +parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[]) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) -- cgit v1.2.3 From 98fd5cde72d5bda1620ab78416c7828fdc3dc10b Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 10 Oct 2022 23:55:48 -0400 Subject: Add check for psutil --- modules/sd_hijack.py | 10 ++++++++-- modules/sd_hijack_optimizations.py | 19 +++++++++++++++---- 2 files changed, 23 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5a1b167f..ac70f876 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -10,6 +10,7 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules.shared import opts, device, cmd_opts +from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model @@ -31,8 +32,13 @@ def apply_optimizations(): print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): - print("Applying cross attention optimization (InvokeAI).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI + if not invokeAI_mps_available and shared.device.type == 'mps': + print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") + print("Applying v1 cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + else: + print("Applying cross attention optimization (InvokeAI).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): print("Applying cross attention optimization (Doggettx).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 870226c5..2a4ac7e0 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,7 @@ import math import sys import traceback -import psutil +import importlib import torch from torch import einsum @@ -117,9 +117,20 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -# -- From https://github.com/invoke-ai/InvokeAI/blob/main/ldm/modules/attention.py (with hypernetworks support added) -- -mem_total_gb = psutil.virtual_memory().total // (1 << 30) +def check_for_psutil(): + try: + spec = importlib.util.find_spec('psutil') + return spec is not None + except ModuleNotFoundError: + return False + +invokeAI_mps_available = check_for_psutil() + +# -- Taken from https://github.com/invoke-ai/InvokeAI -- +if invokeAI_mps_available: + import psutil + mem_total_gb = psutil.virtual_memory().total // (1 << 30) def einsum_op_compvis(q, k, v): s = einsum('b i d, b j d -> b i j', q, k) @@ -193,7 +204,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): r = einsum_op(q, k, v) return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) -# -- End of code from https://github.com/invoke-ai/InvokeAI/blob/main/ldm/modules/attention.py -- +# -- End of code from https://github.com/invoke-ai/InvokeAI -- def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads -- cgit v1.2.3 From 574c8e554a5371eca2cbf344764cb241c6ec4efc Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 11 Oct 2022 03:32:11 -0400 Subject: Add InvokeAI and lstein to credits, add back CUDA support --- modules/sd_hijack_optimizations.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 2a4ac7e0..f006427f 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -173,7 +173,20 @@ def einsum_op_tensor_mem(q, k, v, max_tensor_mb): return einsum_op_slice_0(q, k, v, q.shape[0] // div) return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) +def einsum_op_cuda(q, k, v): + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(q.device) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + # Divide factor of safety as there's copying and fragmentation + return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + def einsum_op(q, k, v): + if q.device.type == 'cuda': + return einsum_op_cuda(q, k, v) + if q.device.type == 'mps': if mem_total_gb >= 32: return einsum_op_mps_v1(q, k, v) -- cgit v1.2.3 From 861db783c7acfcb93cf0b5191db3d50f9a9bc531 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 11 Oct 2022 05:13:17 -0400 Subject: Use apply_hypernetwork function --- modules/sd_hijack_optimizations.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index f006427f..79405525 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -202,16 +202,10 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k = self.to_k(hypernetwork_layers[0](context)) * self.scale - v = self.to_v(hypernetwork_layers[1](context)) - else: - k = self.to_k(context) * self.scale - v = self.to_v(context) - del context, x + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k = self.to_k(context_k) * self.scale + v = self.to_v(context_v) + del context, context_k, context_v, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) r = einsum_op(q, k, v) -- cgit v1.2.3 From d682444ecc99319fbd2b142a12727501e2884ba7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 18:04:47 +0300 Subject: add option to select hypernetwork modules when creating --- modules/hypernetworks/hypernetwork.py | 4 ++-- modules/hypernetworks/ui.py | 4 ++-- modules/ui.py | 2 ++ 3 files changed, 6 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index aa701bda..b081f14e 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -42,7 +42,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None): + def __init__(self, name=None, enable_sizes=None): self.filename = None self.name = name self.layers = {} @@ -50,7 +50,7 @@ class Hypernetwork: self.sd_checkpoint = None self.sd_checkpoint_name = None - for size in [320, 640, 768, 1280]: + for size in enable_sizes or [320, 640, 768, 1280]: self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) def weights(self): diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index e7540f41..cdddcce1 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,11 +9,11 @@ from modules import sd_hijack, shared from modules.hypernetworks import hypernetwork -def create_hypernetwork(name): +def create_hypernetwork(name, enable_sizes): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" - hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name) + hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes]) hypernet.save(fn) shared.reload_hypernetworks() diff --git a/modules/ui.py b/modules/ui.py index f2d16b12..14b87b92 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1037,6 +1037,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="

Create a new hypernetwork

") new_hypernetwork_name = gr.Textbox(label="Name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) with gr.Row(): with gr.Column(scale=3): @@ -1114,6 +1115,7 @@ def create_ui(wrap_gradio_gpu_call): fn=modules.hypernetworks.ui.create_hypernetwork, inputs=[ new_hypernetwork_name, + new_hypernetwork_sizes, ], outputs=[ train_hypernetwork_name, -- cgit v1.2.3 From ff4ef13dd591ec52f196f344f47537695df95364 Mon Sep 17 00:00:00 2001 From: JC_Array Date: Tue, 11 Oct 2022 10:24:27 -0500 Subject: removed unneeded print --- modules/deepbooru.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index e31e92c0..89dcac3c 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -119,7 +119,6 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort) # sort tags result_tags_out = [] sort_ndx = 0 - print(alpha_sort) if alpha_sort: sort_ndx = 1 -- cgit v1.2.3 From 6d09b8d1df3a96e1380bb1650f5961781630af96 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 18:33:57 +0300 Subject: produce error when training with medvram/lowvram enabled --- modules/hypernetworks/ui.py | 2 ++ modules/textual_inversion/ui.py | 3 +++ 2 files changed, 5 insertions(+) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index cdddcce1..3541a388 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -25,6 +25,8 @@ def train_hypernetwork(*args): initial_hypernetwork = shared.loaded_hypernetwork + assert not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram, 'Training models with lowvram or medvram is not possible' + try: sd_hijack.undo_optimizations() diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index c57de1f9..70f47343 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -22,6 +22,9 @@ def preprocess(*args): def train_embedding(*args): + + assert not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram, 'Training models with lowvram or medvram is not possible' + try: sd_hijack.undo_optimizations() -- cgit v1.2.3 From d4ea5f4d8631f778d11efcde397e4a5b8801d43b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 19:03:08 +0300 Subject: add an option to unload models during hypernetwork training to save VRAM --- modules/hypernetworks/hypernetwork.py | 25 +++++++++++++++------- modules/hypernetworks/ui.py | 4 +++- modules/shared.py | 4 ++++ modules/textual_inversion/dataset.py | 29 ++++++++++++++++++-------- modules/textual_inversion/textual_inversion.py | 2 +- 5 files changed, 46 insertions(+), 18 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b081f14e..4700e1ec 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -175,6 +175,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) + unload = shared.opts.unload_models_when_training if save_hypernetwork_every > 0: hypernetwork_dir = os.path.join(log_directory, "hypernetworks") @@ -188,11 +189,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, else: images_dir = None - cond_model = shared.sd_model.cond_stage_model - shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) hypernetwork = shared.loaded_hypernetwork weights = hypernetwork.weights() @@ -211,7 +214,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, return hypernetwork, filename pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, (x, text) in pbar: + for i, (x, text, cond) in pbar: hypernetwork.step = i + ititial_step if hypernetwork.step > steps: @@ -221,11 +224,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, break with torch.autocast("cuda"): - c = cond_model([text]) - + cond = cond.to(devices.device) x = x.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] + loss = shared.sd_model(x.unsqueeze(0), cond)[0] del x + del cond losses[hypernetwork.step % losses.shape[0]] = loss.item() @@ -244,6 +247,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, preview_text = text if preview_image_prompt == "" else preview_image_prompt + optimizer.zero_grad() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, prompt=preview_text, @@ -255,6 +262,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, processed = processing.process_images(p) image = processed.images[0] + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + shared.state.current_image = image image.save(last_saved_image) diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 3541a388..c67facbb 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -5,7 +5,7 @@ import gradio as gr import modules.textual_inversion.textual_inversion import modules.textual_inversion.preprocess -from modules import sd_hijack, shared +from modules import sd_hijack, shared, devices from modules.hypernetworks import hypernetwork @@ -41,5 +41,7 @@ Hypernetwork saved to {html.escape(filename)} raise finally: shared.loaded_hypernetwork = initial_hypernetwork + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) sd_hijack.apply_optimizations() diff --git a/modules/shared.py b/modules/shared.py index 20b45f23..c1092ff7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -228,6 +228,10 @@ options_templates.update(options_section(('system', "System"), { "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), })) +options_templates.update(options_section(('training', "Training"), { + "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"), +})) + options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True), "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 4d006366..f61f40d3 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -8,14 +8,14 @@ from torchvision import transforms import random import tqdm -from modules import devices +from modules import devices, shared import re re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): self.placeholder_token = placeholder_token @@ -32,6 +32,8 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' + cond_model = shared.sd_model.cond_stage_model + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): @@ -53,7 +55,13 @@ class PersonalizedBase(Dataset): init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() init_latent = init_latent.to(devices.cpu) - self.dataset.append((init_latent, filename_tokens)) + if include_cond: + text = self.create_text(filename_tokens) + cond = cond_model([text]).to(devices.cpu) + else: + cond = None + + self.dataset.append((init_latent, filename_tokens, cond)) self.length = len(self.dataset) * repeats @@ -64,6 +72,12 @@ class PersonalizedBase(Dataset): def shuffle(self): self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + def create_text(self, filename_tokens): + text = random.choice(self.lines) + text = text.replace("[name]", self.placeholder_token) + text = text.replace("[filewords]", ' '.join(filename_tokens)) + return text + def __len__(self): return self.length @@ -72,10 +86,7 @@ class PersonalizedBase(Dataset): self.shuffle() index = self.indexes[i % len(self.indexes)] - x, filename_tokens = self.dataset[index] - - text = random.choice(self.lines) - text = text.replace("[name]", self.placeholder_token) - text = text.replace("[filewords]", ' '.join(filename_tokens)) + x, filename_tokens, cond = self.dataset[index] - return x, text + text = self.create_text(filename_tokens) + return x, text, cond diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index bb05cdc6..35f4bd9e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -201,7 +201,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini return embedding, filename pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, (x, text) in pbar: + for i, (x, text, _) in pbar: embedding.step = i + ititial_step if embedding.step > steps: -- cgit v1.2.3 From 6a9ea5b41cf92cd9e980349bb5034439f4e7a58b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 19:22:30 +0300 Subject: prevent extra modules from being saved/loaded with hypernet --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 4700e1ec..5608e799 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -50,7 +50,7 @@ class Hypernetwork: self.sd_checkpoint = None self.sd_checkpoint_name = None - for size in enable_sizes or [320, 640, 768, 1280]: + for size in enable_sizes or []: self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) def weights(self): -- cgit v1.2.3 From c080f52ceae73b893155eff7de577aaf1a982a2f Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 19:37:58 +0100 Subject: move embedding logic to separate file --- modules/textual_inversion/image_embedding.py | 234 +++++++++++++++++++++++++++ 1 file changed, 234 insertions(+) create mode 100644 modules/textual_inversion/image_embedding.py (limited to 'modules') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py new file mode 100644 index 00000000..6ad39602 --- /dev/null +++ b/modules/textual_inversion/image_embedding.py @@ -0,0 +1,234 @@ +import base64 +import json +import numpy as np +import zlib +from PIL import Image,PngImagePlugin,ImageDraw,ImageFont +from fonts.ttf import Roboto +import torch + +class EmbeddingEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, torch.Tensor): + return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()} + return json.JSONEncoder.default(self, obj) + +class EmbeddingDecoder(json.JSONDecoder): + def __init__(self, *args, **kwargs): + json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + def object_hook(self, d): + if 'TORCHTENSOR' in d: + return torch.from_numpy(np.array(d['TORCHTENSOR'])) + return d + +def embedding_to_b64(data): + d = json.dumps(data,cls=EmbeddingEncoder) + return base64.b64encode(d.encode()) + +def embedding_from_b64(data): + d = base64.b64decode(data) + return json.loads(d,cls=EmbeddingDecoder) + +def lcg(m=2**32, a=1664525, c=1013904223, seed=0): + while True: + seed = (a * seed + c) % m + yield seed%255 + +def xor_block(block): + g = lcg() + randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape) + return np.bitwise_xor(block.astype(np.uint8),randblock & 0x0F) + +def style_block(block,sequence): + im = Image.new('RGB',(block.shape[1],block.shape[0])) + draw = ImageDraw.Draw(im) + i=0 + for x in range(-6,im.size[0],8): + for yi,y in enumerate(range(-6,im.size[1],8)): + offset=0 + if yi%2==0: + offset=4 + shade = sequence[i%len(sequence)] + i+=1 + draw.ellipse((x+offset, y, x+6+offset, y+6), fill =(shade,shade,shade) ) + + fg = np.array(im).astype(np.uint8) & 0xF0 + + return block ^ fg + +def insert_image_data_embed(image,data): + d = 3 + data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9) + data_np_ = np.frombuffer(data_compressed,np.uint8).copy() + data_np_high = data_np_ >> 4 + data_np_low = data_np_ & 0x0F + + h = image.size[1] + next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0]%h)) + next_size = next_size + ((h*d)-(next_size%(h*d))) + + data_np_low.resize(next_size) + data_np_low = data_np_low.reshape((h,-1,d)) + + data_np_high.resize(next_size) + data_np_high = data_np_high.reshape((h,-1,d)) + + edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] + edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8) + + data_np_low = style_block(data_np_low,sequence=edge_style) + data_np_low = xor_block(data_np_low) + data_np_high = style_block(data_np_high,sequence=edge_style[::-1]) + data_np_high = xor_block(data_np_high) + + im_low = Image.fromarray(data_np_low,mode='RGB') + im_high = Image.fromarray(data_np_high,mode='RGB') + + background = Image.new('RGB',(image.size[0]+im_low.size[0]+im_high.size[0]+2,image.size[1]),(0,0,0)) + background.paste(im_low,(0,0)) + background.paste(image,(im_low.size[0]+1,0)) + background.paste(im_high,(im_low.size[0]+1+image.size[0]+1,0)) + + return background + +def crop_black(img,tol=0): + mask = (img>tol).all(2) + mask0,mask1 = mask.any(0),mask.any(1) + col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax() + row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax() + return img[row_start:row_end,col_start:col_end] + +def extract_image_data_embed(image): + d=3 + outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F + black_cols = np.where( np.sum(outarr, axis=(0,2))==0) + if black_cols[0].shape[0] < 2: + print('No Image data blocks found.') + return None + + data_block_lower = outarr[:,:black_cols[0].min(),:].astype(np.uint8) + data_block_upper = outarr[:,black_cols[0].max()+1:,:].astype(np.uint8) + + data_block_lower = xor_block(data_block_lower) + data_block_upper = xor_block(data_block_upper) + + data_block = (data_block_upper << 4) | (data_block_lower) + data_block = data_block.flatten().tobytes() + + data = zlib.decompress(data_block) + return json.loads(data,cls=EmbeddingDecoder) + +def addCaptionLines(lines,image,initialx,textfont): + draw = ImageDraw.Draw(image) + hstart =initialx + for fill,line in lines: + fontsize = 32 + font = ImageFont.truetype(textfont, fontsize) + _,_,w, h = draw.textbbox((0,0),line,font=font) + fontsize = min( int(fontsize * ((image.size[0]-35)/w) ), 28) + font = ImageFont.truetype(textfont, fontsize) + _,_,w,h = draw.textbbox((0,0),line,font=font) + draw.text(((image.size[0]-w)/2,hstart), line, font=font, fill=fill) + hstart += h + return hstart + +def caption_image(image,prelines,postlines,background=(51, 51, 51),font=None): + if font is None: + try: + font = ImageFont.truetype(opts.font or Roboto, fontsize) + font = opts.font or Roboto + except Exception: + font = Roboto + + sample_image = image + background = Image.new("RGBA", (sample_image.size[0],sample_image.size[1]+1024), background) + hoffset = addCaptionLines(prelines,background,5,font)+16 + background.paste(sample_image,(0,hoffset)) + hoffset = hoffset+sample_image.size[1]+8 + hoffset = addCaptionLines(postlines,background,hoffset,font) + background = background.crop((0,0,sample_image.size[0],hoffset+8)) + return background + +def caption_image_overlay(srcimage,title,footerLeft,footerMid,footerRight,textfont=None): + from math import cos + + image = srcimage.copy() + + if textfont is None: + try: + textfont = ImageFont.truetype(opts.font or Roboto, fontsize) + textfont = opts.font or Roboto + except Exception: + textfont = Roboto + + factor = 1.5 + gradient = Image.new('RGBA', (1,image.size[1]), color=(0,0,0,0)) + for y in range(image.size[1]): + mag = 1-cos(y/image.size[1]*factor) + mag = max(mag,1-cos((image.size[1]-y)/image.size[1]*factor*1.1)) + gradient.putpixel((0, y), (0,0,0,int(mag*255))) + image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) + + draw = ImageDraw.Draw(image) + fontsize = 32 + font = ImageFont.truetype(textfont, fontsize) + padding = 10 + + _,_,w, h = draw.textbbox((0,0),title,font=font) + fontsize = min( int(fontsize * (((image.size[0]*0.75)-(padding*4))/w) ), 72) + font = ImageFont.truetype(textfont, fontsize) + _,_,w,h = draw.textbbox((0,0),title,font=font) + draw.text((padding,padding), title, anchor='lt', font=font, fill=(255,255,255,230)) + + _,_,w, h = draw.textbbox((0,0),footerLeft,font=font) + fontsize_left = min( int(fontsize * (((image.size[0]/3)-(padding))/w) ), 72) + _,_,w, h = draw.textbbox((0,0),footerMid,font=font) + fontsize_mid = min( int(fontsize * (((image.size[0]/3)-(padding))/w) ), 72) + _,_,w, h = draw.textbbox((0,0),footerRight,font=font) + fontsize_right = min( int(fontsize * (((image.size[0]/3)-(padding))/w) ), 72) + + font = ImageFont.truetype(textfont, min(fontsize_left,fontsize_mid,fontsize_right)) + + draw.text((padding,image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255,255,255,230)) + draw.text((image.size[0]/2,image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255,255,255,230)) + draw.text((image.size[0]-padding,image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255,255,255,230)) + + return image + +if __name__ == '__main__': + + image = Image.new('RGBA',(512,512),(255,255,200,255)) + caption_image(image,[((255,255,255),'line a'),((255,255,255),'line b')], + [((255,255,255),'line c'),((255,255,255),'line d')]) + + image = Image.new('RGBA',(512,512),(255,255,200,255)) + cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight') + + test_embed = {'string_to_param':{'*':torch.from_numpy(np.random.random((2, 4096)))}} + + embedded_image = insert_image_data_embed(cap_image, test_embed) + + retrived_embed = extract_image_data_embed(embedded_image) + + assert str(retrived_embed) == str(test_embed) + + embedded_image2 = insert_image_data_embed(cap_image, retrived_embed) + + assert embedded_image == embedded_image2 + + g = lcg() + shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist() + + reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177, + 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179, + 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193, + 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28, + 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0, + 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185, + 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82, + 204, 86, 73, 222, 44, 198, 118, 240, 97] + + assert shared_random == reference_random + + hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist()) + + assert 12731374 == hunna_kay_random_sum \ No newline at end of file -- cgit v1.2.3 From e5fbf5c755b7c306696546405385d5d2314e555b Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 19:46:33 +0100 Subject: remove embedding related image functions from images --- modules/images.py | 77 ------------------------------------------------------- 1 file changed, 77 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index e62eec8e..c0a90676 100644 --- a/modules/images.py +++ b/modules/images.py @@ -463,80 +463,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i txt_fullfn = None return fullfn, txt_fullfn - -def addCaptionLines(lines,image,initialx,textfont): - draw = ImageDraw.Draw(image) - hstart =initialx - for fill,line in lines: - fontSize = 32 - font = ImageFont.truetype(textfont, fontSize) - _,_,w, h = draw.textbbox((0,0),line,font=font) - fontSize = min( int(fontSize * ((image.size[0]-35)/w) ), 28) - font = ImageFont.truetype(textfont, fontSize) - _,_,w,h = draw.textbbox((0,0),line,font=font) - draw.text(((image.size[0]-w)/2,hstart), line, font=font, fill=fill) - hstart += h - return hstart - -def captionImge(image,prelines,postlines,background=(51, 51, 51),font=None): - if font is None: - try: - font = ImageFont.truetype(opts.font or Roboto, fontsize) - font = opts.font or Roboto - except Exception: - font = Roboto - - sampleImage = image - background = Image.new("RGBA", (sampleImage.size[0],sampleImage.size[1]+1024), background) - hoffset = addCaptionLines(prelines,background,5,font)+16 - background.paste(sampleImage,(0,hoffset)) - hoffset = hoffset+sampleImage.size[1]+8 - hoffset = addCaptionLines(postlines,background,hoffset,font) - background = background.crop((0,0,sampleImage.size[0],hoffset+8)) - return background - -def captionImageOverlay(srcimage,title,footerLeft,footerMid,footerRight,textfont=None): - from math import cos - - image = srcimage.copy() - - if textfont is None: - try: - textfont = ImageFont.truetype(opts.font or Roboto, fontsize) - textfont = opts.font or Roboto - except Exception: - textfont = Roboto - - factor = 1.5 - gradient = Image.new('RGBA', (1,image.size[1]), color=(0,0,0,0)) - for y in range(image.size[1]): - mag = 1-cos(y/image.size[1]*factor) - mag = max(mag,1-cos((image.size[1]-y)/image.size[1]*factor*1.1)) - gradient.putpixel((0, y), (0,0,0,int(mag*255))) - image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) - - draw = ImageDraw.Draw(image) - fontSize = 32 - font = ImageFont.truetype(textfont, fontSize) - padding = 10 - - _,_,w, h = draw.textbbox((0,0),title,font=font) - fontSize = min( int(fontSize * (((image.size[0]*0.75)-(padding*4))/w) ), 72) - font = ImageFont.truetype(textfont, fontSize) - _,_,w,h = draw.textbbox((0,0),title,font=font) - draw.text((padding,padding), title, anchor='lt', font=font, fill=(255,255,255,230)) - - _,_,w, h = draw.textbbox((0,0),footerLeft,font=font) - fontSizeleft = min( int(fontSize * (((image.size[0]/3)-(padding))/w) ), 72) - _,_,w, h = draw.textbbox((0,0),footerMid,font=font) - fontSizemid = min( int(fontSize * (((image.size[0]/3)-(padding))/w) ), 72) - _,_,w, h = draw.textbbox((0,0),footerRight,font=font) - fontSizeright = min( int(fontSize * (((image.size[0]/3)-(padding))/w) ), 72) - - font = ImageFont.truetype(textfont, min(fontSizeleft,fontSizemid,fontSizeright)) - - draw.text((padding,image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255,255,255,230)) - draw.text((image.size[0]/2,image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255,255,255,230)) - draw.text((image.size[0]-padding,image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255,255,255,230)) - - return image -- cgit v1.2.3 From 61788c0538415fa9ca1dd1b306519c116b18bd2c Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 19:50:50 +0100 Subject: shift embedding logic out of textual_inversion --- modules/textual_inversion/textual_inversion.py | 125 ++----------------------- 1 file changed, 6 insertions(+), 119 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 8c66aeb5..22b4ae7f 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,124 +7,11 @@ import tqdm import html import datetime -from PIL import Image,PngImagePlugin,ImageDraw -from ..images import captionImageOverlay -import numpy as np -import base64 -import json -import zlib +from PIL import Image,PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset -class EmbeddingEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, torch.Tensor): - return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()} - return json.JSONEncoder.default(self, obj) - -class EmbeddingDecoder(json.JSONDecoder): - def __init__(self, *args, **kwargs): - json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) - def object_hook(self, d): - if 'TORCHTENSOR' in d: - return torch.from_numpy(np.array(d['TORCHTENSOR'])) - return d - -def embeddingToB64(data): - d = json.dumps(data,cls=EmbeddingEncoder) - return base64.b64encode(d.encode()) - -def embeddingFromB64(data): - d = base64.b64decode(data) - return json.loads(d,cls=EmbeddingDecoder) - -def lcg(m=2**32, a=1664525, c=1013904223, seed=0): - while True: - seed = (a * seed + c) % m - yield seed - -def xorBlock(block): - g = lcg() - randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape) - return np.bitwise_xor(block.astype(np.uint8),randblock & 0x0F) - -def styleBlock(block,sequence): - im = Image.new('RGB',(block.shape[1],block.shape[0])) - draw = ImageDraw.Draw(im) - i=0 - for x in range(-6,im.size[0],8): - for yi,y in enumerate(range(-6,im.size[1],8)): - offset=0 - if yi%2==0: - offset=4 - shade = sequence[i%len(sequence)] - i+=1 - draw.ellipse((x+offset, y, x+6+offset, y+6), fill =(shade,shade,shade) ) - - fg = np.array(im).astype(np.uint8) & 0xF0 - return block ^ fg - -def insertImageDataEmbed(image,data): - d = 3 - data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9) - dnp = np.frombuffer(data_compressed,np.uint8).copy() - dnphigh = dnp >> 4 - dnplow = dnp & 0x0F - - h = image.size[1] - next_size = dnplow.shape[0] + (h-(dnplow.shape[0]%h)) - next_size = next_size + ((h*d)-(next_size%(h*d))) - - dnplow.resize(next_size) - dnplow = dnplow.reshape((h,-1,d)) - - dnphigh.resize(next_size) - dnphigh = dnphigh.reshape((h,-1,d)) - - edgeStyleWeights = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] - edgeStyleWeights = (np.abs(edgeStyleWeights)/np.max(np.abs(edgeStyleWeights))*255).astype(np.uint8) - - dnplow = styleBlock(dnplow,sequence=edgeStyleWeights) - dnplow = xorBlock(dnplow) - dnphigh = styleBlock(dnphigh,sequence=edgeStyleWeights[::-1]) - dnphigh = xorBlock(dnphigh) - - imlow = Image.fromarray(dnplow,mode='RGB') - imhigh = Image.fromarray(dnphigh,mode='RGB') - - background = Image.new('RGB',(image.size[0]+imlow.size[0]+imhigh.size[0]+2,image.size[1]),(0,0,0)) - background.paste(imlow,(0,0)) - background.paste(image,(imlow.size[0]+1,0)) - background.paste(imhigh,(imlow.size[0]+1+image.size[0]+1,0)) - - return background - -def crop_black(img,tol=0): - mask = (img>tol).all(2) - mask0,mask1 = mask.any(0),mask.any(1) - col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax() - row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax() - return img[row_start:row_end,col_start:col_end] - -def extractImageDataEmbed(image): - d=3 - outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F - blackCols = np.where( np.sum(outarr, axis=(0,2))==0) - if blackCols[0].shape[0] < 2: - print('No Image data blocks found.') - return None - - dataBlocklower = outarr[:,:blackCols[0].min(),:].astype(np.uint8) - dataBlockupper = outarr[:,blackCols[0].max()+1:,:].astype(np.uint8) - - dataBlocklower = xorBlock(dataBlocklower) - dataBlockupper = xorBlock(dataBlockupper) - - dataBlock = (dataBlockupper << 4) | (dataBlocklower) - dataBlock = dataBlock.flatten().tobytes() - data = zlib.decompress(dataBlock) - return json.loads(data,cls=EmbeddingDecoder) class Embedding: def __init__(self, vec, name, step=None): @@ -199,10 +86,10 @@ class EmbeddingDatabase: if filename.upper().endswith('.PNG'): embed_image = Image.open(path) if 'sd-ti-embedding' in embed_image.text: - data = embeddingFromB64(embed_image.text['sd-ti-embedding']) + data = embedding_from_b64(embed_image.text['sd-ti-embedding']) name = data.get('name',name) else: - data = extractImageDataEmbed(embed_image) + data = extract_image_data_embed(embed_image) name = data.get('name',name) else: data = torch.load(path, map_location="cpu") @@ -393,7 +280,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini info = PngImagePlugin.PngInfo() data = torch.load(last_saved_file) - info.add_text("sd-ti-embedding", embeddingToB64(data)) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) title = "<{}>".format(data.get('name','???')) checkpoint = sd_models.select_checkpoint() @@ -401,8 +288,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini footer_mid = '[{}]'.format(checkpoint.hash) footer_right = '{}'.format(embedding.step) - captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right) - captioned_image = insertImageDataEmbed(captioned_image,data) + captioned_image = caption_image_overlay(image,title,footer_left,footer_mid,footer_right) + captioned_image = insert_image_data_embed(captioned_image,data) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) -- cgit v1.2.3 From db71290d2659d3b58ff9b57a82e4721a9eab9229 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 19:55:54 +0100 Subject: remove old caption method --- modules/textual_inversion/image_embedding.py | 39 ++-------------------------- 1 file changed, 2 insertions(+), 37 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index 6ad39602..c67028a5 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -117,37 +117,6 @@ def extract_image_data_embed(image): data = zlib.decompress(data_block) return json.loads(data,cls=EmbeddingDecoder) -def addCaptionLines(lines,image,initialx,textfont): - draw = ImageDraw.Draw(image) - hstart =initialx - for fill,line in lines: - fontsize = 32 - font = ImageFont.truetype(textfont, fontsize) - _,_,w, h = draw.textbbox((0,0),line,font=font) - fontsize = min( int(fontsize * ((image.size[0]-35)/w) ), 28) - font = ImageFont.truetype(textfont, fontsize) - _,_,w,h = draw.textbbox((0,0),line,font=font) - draw.text(((image.size[0]-w)/2,hstart), line, font=font, fill=fill) - hstart += h - return hstart - -def caption_image(image,prelines,postlines,background=(51, 51, 51),font=None): - if font is None: - try: - font = ImageFont.truetype(opts.font or Roboto, fontsize) - font = opts.font or Roboto - except Exception: - font = Roboto - - sample_image = image - background = Image.new("RGBA", (sample_image.size[0],sample_image.size[1]+1024), background) - hoffset = addCaptionLines(prelines,background,5,font)+16 - background.paste(sample_image,(0,hoffset)) - hoffset = hoffset+sample_image.size[1]+8 - hoffset = addCaptionLines(postlines,background,hoffset,font) - background = background.crop((0,0,sample_image.size[0],hoffset+8)) - return background - def caption_image_overlay(srcimage,title,footerLeft,footerMid,footerRight,textfont=None): from math import cos @@ -195,11 +164,7 @@ def caption_image_overlay(srcimage,title,footerLeft,footerMid,footerRight,textfo return image if __name__ == '__main__': - - image = Image.new('RGBA',(512,512),(255,255,200,255)) - caption_image(image,[((255,255,255),'line a'),((255,255,255),'line b')], - [((255,255,255),'line c'),((255,255,255),'line d')]) - + image = Image.new('RGBA',(512,512),(255,255,200,255)) cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight') @@ -231,4 +196,4 @@ if __name__ == '__main__': hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist()) - assert 12731374 == hunna_kay_random_sum \ No newline at end of file + assert 12731374 == hunna_kay_random_sum -- cgit v1.2.3 From d6fcc6b87bc00fcdecea276fe5b7c7945f7a8b14 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 22:03:05 +0300 Subject: apply lr schedule to hypernets --- modules/hypernetworks/hypernetwork.py | 19 ++++++++--- modules/textual_inversion/learn_schedule.py | 34 ++++++++++++++++++++ modules/textual_inversion/textual_inversion.py | 44 +++----------------------- modules/ui.py | 2 +- 4 files changed, 54 insertions(+), 45 deletions(-) create mode 100644 modules/textual_inversion/learn_schedule.py (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 5608e799..470659df 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -14,6 +14,7 @@ import torch from torch import einsum from einops import rearrange, repeat import modules.textual_inversion.dataset +from modules.textual_inversion.learn_schedule import LearnSchedule class HypernetworkModule(torch.nn.Module): @@ -202,8 +203,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, for weight in weights: weight.requires_grad = True - optimizer = torch.optim.AdamW(weights, lr=learn_rate) - losses = torch.zeros((32,)) last_saved_file = "" @@ -213,12 +212,24 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if ititial_step > steps: return hypernetwork, filename + schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) + (learn_rate, end_step) = next(schedules) + print(f'Training at rate of {learn_rate} until step {end_step}') + + optimizer = torch.optim.AdamW(weights, lr=learn_rate) + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, (x, text, cond) in pbar: hypernetwork.step = i + ititial_step - if hypernetwork.step > steps: - break + if hypernetwork.step > end_step: + try: + (learn_rate, end_step) = next(schedules) + except Exception: + break + tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}') + for pg in optimizer.param_groups: + pg['lr'] = learn_rate if shared.state.interrupted: break diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py new file mode 100644 index 00000000..db720271 --- /dev/null +++ b/modules/textual_inversion/learn_schedule.py @@ -0,0 +1,34 @@ + +class LearnSchedule: + def __init__(self, learn_rate, max_steps, cur_step=0): + pairs = learn_rate.split(',') + self.rates = [] + self.it = 0 + self.maxit = 0 + for i, pair in enumerate(pairs): + tmp = pair.split(':') + if len(tmp) == 2: + step = int(tmp[1]) + if step > cur_step: + self.rates.append((float(tmp[0]), min(step, max_steps))) + self.maxit += 1 + if step > max_steps: + return + elif step == -1: + self.rates.append((float(tmp[0]), max_steps)) + self.maxit += 1 + return + else: + self.rates.append((float(tmp[0]), max_steps)) + self.maxit += 1 + return + + def __iter__(self): + return self + + def __next__(self): + if self.it < self.maxit: + self.it += 1 + return self.rates[self.it - 1] + else: + raise StopIteration diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 47a27faf..7717837d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -10,6 +10,7 @@ import datetime from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset +from modules.textual_inversion.learn_schedule import LearnSchedule class Embedding: @@ -198,11 +199,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if ititial_step > steps: return embedding, filename - tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]) - epoch_len = (tr_img_len * num_repeats) + tr_img_len - - scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step)) - (learn_rate, end_step) = next(scheduleIter) + schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) + (learn_rate, end_step) = next(schedules) print(f'Training at rate of {learn_rate} until step {end_step}') optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) @@ -213,7 +211,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > end_step: try: - (learn_rate, end_step) = next(scheduleIter) + (learn_rate, end_step) = next(schedules) except: break tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}') @@ -288,37 +286,3 @@ Last saved image: {html.escape(last_saved_image)}
embedding.save(filename) return embedding, filename - -class LearnSchedule: - def __init__(self, learn_rate, max_steps, cur_step=0): - pairs = learn_rate.split(',') - self.rates = [] - self.it = 0 - self.maxit = 0 - for i, pair in enumerate(pairs): - tmp = pair.split(':') - if len(tmp) == 2: - step = int(tmp[1]) - if step > cur_step: - self.rates.append((float(tmp[0]), min(step, max_steps))) - self.maxit += 1 - if step > max_steps: - return - elif step == -1: - self.rates.append((float(tmp[0]), max_steps)) - self.maxit += 1 - return - else: - self.rates.append((float(tmp[0]), max_steps)) - self.maxit += 1 - return - - def __iter__(self): - return self - - def __next__(self): - if self.it < self.maxit: - self.it += 1 - return self.rates[self.it - 1] - else: - raise StopIteration diff --git a/modules/ui.py b/modules/ui.py index 2b688e32..1204eef7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1070,7 +1070,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) - learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value = "5.0e-03") + learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) -- cgit v1.2.3 From aa75d5cfe8c84768b0f5d16f977ddba298677379 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:06:13 +0100 Subject: correct conflict resolution typo --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 22b4ae7f..789383ce 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -169,7 +169,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt) +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." -- cgit v1.2.3 From 91d7ee0d097a7ea203d261b570cd2b834837d9e2 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:09:10 +0100 Subject: update imports --- modules/textual_inversion/textual_inversion.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 789383ce..ff0a62b3 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -12,6 +12,9 @@ from PIL import Image,PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset +from modules.textual_inversion.image_embedding import( embedding_to_b64,embedding_from_b64, + insert_image_data_embed,extract_image_data_embed, + caption_image_overlay ) class Embedding: def __init__(self, vec, name, step=None): -- cgit v1.2.3 From 5f3317376bb7952bc5145f05f16c1bbd466efc85 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:09:49 +0100 Subject: spacing --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ff0a62b3..485ef46c 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -12,7 +12,7 @@ from PIL import Image,PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset -from modules.textual_inversion.image_embedding import( embedding_to_b64,embedding_from_b64, +from modules.textual_inversion.image_embedding import (embedding_to_b64,embedding_from_b64, insert_image_data_embed,extract_image_data_embed, caption_image_overlay ) -- cgit v1.2.3 From 7e6a6e00ad6f3b7ef43c8120db9ecac6e8d6bea5 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:20:46 +0100 Subject: Add files via upload --- modules/textual_inversion/test_embedding.png | Bin 0 -> 489220 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 modules/textual_inversion/test_embedding.png (limited to 'modules') diff --git a/modules/textual_inversion/test_embedding.png b/modules/textual_inversion/test_embedding.png new file mode 100644 index 00000000..07e2d9af Binary files /dev/null and b/modules/textual_inversion/test_embedding.png differ -- cgit v1.2.3 From 66ec505975aaa305a217fc27281ce368cbaef281 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:21:30 +0100 Subject: add file based test --- modules/textual_inversion/image_embedding.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'modules') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index c67028a5..1224fb42 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -164,6 +164,14 @@ def caption_image_overlay(srcimage,title,footerLeft,footerMid,footerRight,textfo return image if __name__ == '__main__': + + testEmbed = Image.open('test_embedding.png') + + data = extract_image_data_embed(testEmbed) + assert data is not None + + data = embedding_from_b64(testEmbed.text['sd-ti-embedding']) + assert data is not None image = Image.new('RGBA',(512,512),(255,255,200,255)) cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight') -- cgit v1.2.3 From 6be32b31d181e42c639dad3451229aa7b9cfd1cf Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 23:07:09 +0300 Subject: reports that training with medvram is possible. --- modules/hypernetworks/ui.py | 2 +- modules/textual_inversion/ui.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index c67facbb..dfa599af 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -25,7 +25,7 @@ def train_hypernetwork(*args): initial_hypernetwork = shared.loaded_hypernetwork - assert not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram, 'Training models with lowvram or medvram is not possible' + assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' try: sd_hijack.undo_optimizations() diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index 70f47343..36881e7a 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -23,7 +23,7 @@ def preprocess(*args): def train_embedding(*args): - assert not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram, 'Training models with lowvram or medvram is not possible' + assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' try: sd_hijack.undo_optimizations() -- cgit v1.2.3 From f53f703aebc801c4204182d52bb1e0bef9808e1f Mon Sep 17 00:00:00 2001 From: JC_Array Date: Tue, 11 Oct 2022 18:12:12 -0500 Subject: resolved conflicts, moved settings under interrogate section, settings only show if deepbooru flag is enabled --- modules/deepbooru.py | 2 +- modules/shared.py | 19 +++++++++---------- modules/textual_inversion/preprocess.py | 2 +- modules/ui.py | 2 +- 4 files changed, 12 insertions(+), 13 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 89dcac3c..29529949 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -8,7 +8,7 @@ def get_deepbooru_tags(pil_image): This method is for running only one image at a time for simple use. Used to the img2img interrogate. """ from modules import shared # prevents circular reference - create_deepbooru_process(shared.opts.deepbooru_threshold, shared.opts.deepbooru_sort_alpha) + create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, shared.opts.deepbooru_sort_alpha) shared.deepbooru_process_return["value"] = -1 shared.deepbooru_process_queue.put(pil_image) while shared.deepbooru_process_return["value"] == -1: diff --git a/modules/shared.py b/modules/shared.py index 817203f8..5456c477 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -248,15 +248,20 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -options_templates.update(options_section(('interrogate', "Interrogate Options"), { +interrogate_option_dictionary = { "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), - "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"), - "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), -})) + "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)") +} + +if cmd_opts.deepdanbooru: + interrogate_option_dictionary["interrogate_deepbooru_score_threshold"] = OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}) + interrogate_option_dictionary["deepbooru_sort_alpha"] = OptionInfo(True, "Interrogate: deepbooru sort alphabetically", gr.Checkbox) + +options_templates.update(options_section(('interrogate', "Interrogate Options"), interrogate_option_dictionary)) options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), @@ -282,12 +287,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) -if cmd_opts.deepdanbooru: - options_templates.update(options_section(('deepbooru-params', "DeepBooru parameters"), { - "deepbooru_sort_alpha": OptionInfo(True, "Sort Alphabetical", gr.Checkbox), - 'deepbooru_threshold': OptionInfo(0.5, "Threshold", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - })) - class Options: data = None diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index a96388d6..113cecf1 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -29,7 +29,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ shared.interrogator.load() if process_caption_deepbooru: - deepbooru.create_deepbooru_process(opts.deepbooru_threshold, opts.deepbooru_sort_alpha) + deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, opts.deepbooru_sort_alpha) def save_pic_with_caption(image, index): if process_caption: diff --git a/modules/ui.py b/modules/ui.py index 2891fc8c..fa45edca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -317,7 +317,7 @@ def interrogate(image): def interrogate_deepbooru(image): - prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold) + prompt = get_deepbooru_tags(image) return gr_show(True) if prompt is None else prompt -- cgit v1.2.3 From 65b973ac4e547a325f30a05f852b161421af2041 Mon Sep 17 00:00:00 2001 From: supersteve3d <39339941+supersteve3d@users.noreply.github.com> Date: Wed, 12 Oct 2022 08:21:52 +0800 Subject: Update shared.py Correct typo to "Unload VAE and CLIP from VRAM when training" in settings tab. --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index c1092ff7..46bc740c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -229,7 +229,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('training', "Training"), { - "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"), + "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { -- cgit v1.2.3 From d717eb079cd6b7fa7a4f97c0a10d400bdec753fb Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Tue, 11 Oct 2022 18:02:41 -0700 Subject: Interrogate: add option to include ranks in output Since the UI also allows users to specify ranks, it can be useful to show people what ranks are being returned by interrogate This can also give much better results when feeding the interrogate results back into either img2img or txt2img, especially when trying to generate a specific character or scene for which you have a similar concept image Testing Steps: Launch Webui with command line arg: --deepdanbooru Navigate to img2img tab, use interrogate DeepBooru, verify tags appears as before. Use "Interrogate CLIP", verify prompt appears as before Navigate to Settings tab, enable new option, click "apply settings" Navigate to img2img, Interrogate DeepBooru again, verify that weights appear and are properly formatted. Note that "Interrogate CLIP" prompt is still unchanged In my testing, this change has no effect to "Interrogate CLIP", as it seems to generate a sentence-structured caption, and not a set of tags. (reproduce changes from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2149/commits/6ed4faac46c45ca7353f228aca9b436bbaba7bc7) --- modules/deepbooru.py | 14 +++++++++----- modules/interrogate.py | 7 +++++-- modules/shared.py | 1 + modules/ui.py | 5 ++--- 4 files changed, 17 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 7e3c0618..32d741e2 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -3,7 +3,7 @@ from concurrent.futures import ProcessPoolExecutor from multiprocessing import get_context -def _load_tf_and_return_tags(pil_image, threshold): +def _load_tf_and_return_tags(pil_image, threshold, include_ranks): import deepdanbooru as dd import tensorflow as tf import numpy as np @@ -52,12 +52,16 @@ def _load_tf_and_return_tags(pil_image, threshold): if result_dict[tag] >= threshold: if tag.startswith("rating:"): continue - result_tags_out.append(tag) + tag_formatted = tag.replace('_', ' ').replace(':', ' ') + if include_ranks: + result_tags_out.append(f'({tag_formatted}:{result_dict[tag]})') + else: + result_tags_out.append(tag_formatted) result_tags_print.append(f'{result_dict[tag]} {tag}') print('\n'.join(sorted(result_tags_print, reverse=True))) - return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') + return ', '.join(result_tags_out) def subprocess_init_no_cuda(): @@ -65,9 +69,9 @@ def subprocess_init_no_cuda(): os.environ["CUDA_VISIBLE_DEVICES"] = "-1" -def get_deepbooru_tags(pil_image, threshold=0.5): +def get_deepbooru_tags(pil_image, threshold=0.5, include_ranks=False): context = get_context('spawn') with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor: - f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, ) + f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, include_ranks) ret = f.result() # will rethrow any exceptions return ret \ No newline at end of file diff --git a/modules/interrogate.py b/modules/interrogate.py index 635e266e..af858cc0 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -123,7 +123,7 @@ class InterrogateModels: return caption[0] - def interrogate(self, pil_image): + def interrogate(self, pil_image, include_ranks=False): res = None try: @@ -156,7 +156,10 @@ class InterrogateModels: for name, topn, items in self.categories: matches = self.rank(image_features, items, top_count=topn) for match, score in matches: - res += ", " + match + if include_ranks: + res += ", " + match + else: + res += f", ({match}:{score})" except Exception: print(f"Error interrogating", file=sys.stderr) diff --git a/modules/shared.py b/modules/shared.py index c1092ff7..3e0bfd72 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -251,6 +251,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), { "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), + "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), diff --git a/modules/ui.py b/modules/ui.py index 1204eef7..f4dbe247 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -311,13 +311,12 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name): def interrogate(image): - prompt = shared.interrogator.interrogate(image) - + prompt = shared.interrogator.interrogate(image, include_ranks=opts.interrogate_return_ranks) return gr_show(True) if prompt is None else prompt def interrogate_deepbooru(image): - prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold) + prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold, opts.interrogate_return_ranks) return gr_show(True) if prompt is None else prompt -- cgit v1.2.3 From 6ac2ec2b78bc5fabd09cb866dd9a71061d669269 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 07:01:20 +0300 Subject: create dir for hypernetworks --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index c1092ff7..e65e77f8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -86,6 +86,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram xformers_available = False config_filename = cmd_opts.ui_settings_file +os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None -- cgit v1.2.3 From fec2221eeaafb50afd26ba3e109bf6f928011e69 Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Tue, 11 Oct 2022 19:29:38 -0700 Subject: Truncate error text to fix service lockup / stall What: * Update wrap_gradio_call to add a limit to the maximum amount of text output Why: * wrap_gradio_call currently prints out a list of the arguments provided to the failing function. * if that function is save_image, this causes the entire image to be printed to stderr * If the image is large, this can cause the service to lock up while attempting to print all the text * It is easy to generate large images using the x/y plot script * it is easy to encounter image save exceptions, including if the output directory does not exist / cannot be written to, or if the file is too big * The huge amount of log spam is confusing and not particularly helpful --- modules/ui.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 1204eef7..33a49d3b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -181,8 +181,15 @@ def wrap_gradio_call(func, extra_outputs=None): try: res = list(func(*args, **kwargs)) except Exception as e: + # When printing out our debug argument list, do not print out more than a MB of text + max_debug_str_len = 131072 # (1024*1024)/8 + print("Error completing request", file=sys.stderr) - print("Arguments:", args, kwargs, file=sys.stderr) + argStr = f"Arguments: {str(args)} {str(kwargs)}" + print(argStr[:max_debug_str_len], file=sys.stderr) + if len(argStr) > max_debug_str_len: + print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) shared.state.job = "" -- cgit v1.2.3 From 336bd8703c7b4d71f2f096f303599925a30b8167 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 09:00:07 +0300 Subject: just add the deepdanbooru settings unconditionally --- modules/shared.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index f150e024..42e99741 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -249,20 +249,15 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -interrogate_option_dictionary = { +options_templates.update(options_section(('interrogate', "Interrogate Options"), { "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), - "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)") -} - -if cmd_opts.deepdanbooru: - interrogate_option_dictionary["interrogate_deepbooru_score_threshold"] = OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}) - interrogate_option_dictionary["deepbooru_sort_alpha"] = OptionInfo(True, "Interrogate: deepbooru sort alphabetically", gr.Checkbox) - -options_templates.update(options_section(('interrogate', "Interrogate Options"), interrogate_option_dictionary)) + "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), + "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), +})) options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), -- cgit v1.2.3 From 57e03cdd244eee4e33ccab7554b3594563a3d0cd Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 12 Oct 2022 00:54:24 -0400 Subject: Ensure the directory exists before saving to it The directory for the images saved with the Save button may still not exist, so it needs to be created prior to opening the log.csv file. --- modules/ui.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 00bf09ae..cd67b84b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -131,6 +131,8 @@ def save_files(js_data, images, do_make_zip, index): images = [images[index]] start_index = index + os.makedirs(opts.outdir_save, exist_ok=True) + with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: at_start = file.tell() == 0 writer = csv.writer(file) -- cgit v1.2.3 From 2d006ce16cd95d587533656c3ac4991495e96f23 Mon Sep 17 00:00:00 2001 From: Milly Date: Mon, 10 Oct 2022 00:56:36 +0900 Subject: xy_grid: Find hypernetwork by closest name --- modules/hypernetworks/hypernetwork.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 470659df..8f2192e2 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -120,6 +120,17 @@ def load_hypernetwork(filename): shared.loaded_hypernetwork = None +def find_closest_hypernetwork_name(search: str): + if not search: + return None + search = search.lower() + applicable = [name for name in shared.hypernetworks if search in name.lower()] + if not applicable: + return None + applicable = sorted(applicable, key=lambda name: len(name)) + return applicable[0] + + def apply_hypernetwork(hypernetwork, context, layer=None): hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) -- cgit v1.2.3 From ee015a1af66a94a75c914659fa0d321e702a0a87 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 11:05:57 +0300 Subject: change textual inversion tab to train remake train interface to use tabs --- modules/hypernetworks/hypernetwork.py | 2 +- modules/ui.py | 22 +++++++++------------- 2 files changed, 10 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8f2192e2..8314450a 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -175,7 +175,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): - assert hypernetwork_name, 'embedding not selected' + assert hypernetwork_name, 'hypernetwork not selected' path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() diff --git a/modules/ui.py b/modules/ui.py index 4bfdd275..86a2da6c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1035,14 +1035,14 @@ def create_ui(wrap_gradio_gpu_call): sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() - with gr.Blocks() as textual_inversion_interface: + with gr.Blocks() as train_interface: with gr.Row().style(equal_height=False): - with gr.Column(): - with gr.Group(): - gr.HTML(value="

See wiki for detailed explanation.

") + gr.HTML(value="

See wiki for detailed explanation.

") - gr.HTML(value="

Create a new embedding

") + with gr.Row().style(equal_height=False): + with gr.Tabs(elem_id="train_tabs"): + with gr.Tab(label="Create embedding"): new_embedding_name = gr.Textbox(label="Name") initialization_text = gr.Textbox(label="Initialization text", value="*") nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) @@ -1054,9 +1054,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create embedding", variant='primary') - with gr.Group(): - gr.HTML(value="

Create a new hypernetwork

") - + with gr.Tab(label="Create hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) @@ -1067,9 +1065,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') - with gr.Group(): - gr.HTML(value="

Preprocess images

") - + with gr.Tab(label="Preprocess images"): process_src = gr.Textbox(label='Source directory') process_dst = gr.Textbox(label='Destination directory') process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) @@ -1091,7 +1087,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): run_preprocess = gr.Button(value="Preprocess", variant='primary') - with gr.Group(): + with gr.Tab(label="Train"): gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) @@ -1388,7 +1384,7 @@ Requested path was: {f} (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), - (textual_inversion_interface, "Textual inversion", "ti"), + (train_interface, "Train", "ti"), (settings_interface, "Settings", "settings"), ] -- cgit v1.2.3 From 80f3cf2bb2ce3f00d801cae2c3a8c20a8d4167d8 Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Tue, 11 Oct 2022 19:48:53 +0100 Subject: Account when lines are mismatched --- modules/sd_hijack.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ac70f876..2753d4fa 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -321,7 +321,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): fixes.append(fix[1]) self.hijack.fixes.append(fixes) - z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers]) + tokens = [] + multipliers = [] + for i in range(len(remade_batch_tokens)): + if len(remade_batch_tokens[i]) > 0: + tokens.append(remade_batch_tokens[i][:75]) + multipliers.append(batch_multipliers[i][:75]) + else: + tokens.append([self.wrapped.tokenizer.eos_token_id] * 75) + multipliers.append([1.0] * 75) + + z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) remade_batch_tokens = rem_tokens -- cgit v1.2.3 From 429442f4a6aab7301efb89d27bef524fe827e81a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 13:38:03 +0300 Subject: fix iterator bug for #2295 --- modules/sd_hijack.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2753d4fa..c81722a0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -323,10 +323,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): tokens = [] multipliers = [] - for i in range(len(remade_batch_tokens)): - if len(remade_batch_tokens[i]) > 0: - tokens.append(remade_batch_tokens[i][:75]) - multipliers.append(batch_multipliers[i][:75]) + for j in range(len(remade_batch_tokens)): + if len(remade_batch_tokens[j]) > 0: + tokens.append(remade_batch_tokens[j][:75]) + multipliers.append(batch_multipliers[j][:75]) else: tokens.append([self.wrapped.tokenizer.eos_token_id] * 75) multipliers.append([1.0] * 75) -- cgit v1.2.3 From 50be33e953be93c40814262c6dbce36e66004528 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 12 Oct 2022 13:13:25 +0100 Subject: formatting --- modules/textual_inversion/image_embedding.py | 170 ++++++++++++++------------- 1 file changed, 91 insertions(+), 79 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index 1224fb42..898ce3b3 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -2,122 +2,134 @@ import base64 import json import numpy as np import zlib -from PIL import Image,PngImagePlugin,ImageDraw,ImageFont +from PIL import Image, PngImagePlugin, ImageDraw, ImageFont from fonts.ttf import Roboto import torch + class EmbeddingEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, torch.Tensor): - return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()} + return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()} return json.JSONEncoder.default(self, obj) + class EmbeddingDecoder(json.JSONDecoder): def __init__(self, *args, **kwargs): json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + def object_hook(self, d): if 'TORCHTENSOR' in d: return torch.from_numpy(np.array(d['TORCHTENSOR'])) return d + def embedding_to_b64(data): - d = json.dumps(data,cls=EmbeddingEncoder) + d = json.dumps(data, cls=EmbeddingEncoder) return base64.b64encode(d.encode()) + def embedding_from_b64(data): d = base64.b64decode(data) - return json.loads(d,cls=EmbeddingDecoder) + return json.loads(d, cls=EmbeddingDecoder) + def lcg(m=2**32, a=1664525, c=1013904223, seed=0): while True: seed = (a * seed + c) % m - yield seed%255 + yield seed % 255 + def xor_block(block): g = lcg() randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape) - return np.bitwise_xor(block.astype(np.uint8),randblock & 0x0F) + return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F) -def style_block(block,sequence): - im = Image.new('RGB',(block.shape[1],block.shape[0])) + +def style_block(block, sequence): + im = Image.new('RGB', (block.shape[1], block.shape[0])) draw = ImageDraw.Draw(im) - i=0 - for x in range(-6,im.size[0],8): - for yi,y in enumerate(range(-6,im.size[1],8)): - offset=0 - if yi%2==0: - offset=4 - shade = sequence[i%len(sequence)] - i+=1 - draw.ellipse((x+offset, y, x+6+offset, y+6), fill =(shade,shade,shade) ) + i = 0 + for x in range(-6, im.size[0], 8): + for yi, y in enumerate(range(-6, im.size[1], 8)): + offset = 0 + if yi % 2 == 0: + offset = 4 + shade = sequence[i % len(sequence)] + i += 1 + draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade)) fg = np.array(im).astype(np.uint8) & 0xF0 return block ^ fg -def insert_image_data_embed(image,data): + +def insert_image_data_embed(image, data): d = 3 - data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9) - data_np_ = np.frombuffer(data_compressed,np.uint8).copy() + data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9) + data_np_ = np.frombuffer(data_compressed, np.uint8).copy() data_np_high = data_np_ >> 4 - data_np_low = data_np_ & 0x0F - + data_np_low = data_np_ & 0x0F + h = image.size[1] - next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0]%h)) - next_size = next_size + ((h*d)-(next_size%(h*d))) + next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h)) + next_size = next_size + ((h*d)-(next_size % (h*d))) data_np_low.resize(next_size) - data_np_low = data_np_low.reshape((h,-1,d)) + data_np_low = data_np_low.reshape((h, -1, d)) data_np_high.resize(next_size) - data_np_high = data_np_high.reshape((h,-1,d)) + data_np_high = data_np_high.reshape((h, -1, d)) edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8) - data_np_low = style_block(data_np_low,sequence=edge_style) - data_np_low = xor_block(data_np_low) - data_np_high = style_block(data_np_high,sequence=edge_style[::-1]) - data_np_high = xor_block(data_np_high) + data_np_low = style_block(data_np_low, sequence=edge_style) + data_np_low = xor_block(data_np_low) + data_np_high = style_block(data_np_high, sequence=edge_style[::-1]) + data_np_high = xor_block(data_np_high) - im_low = Image.fromarray(data_np_low,mode='RGB') - im_high = Image.fromarray(data_np_high,mode='RGB') + im_low = Image.fromarray(data_np_low, mode='RGB') + im_high = Image.fromarray(data_np_high, mode='RGB') - background = Image.new('RGB',(image.size[0]+im_low.size[0]+im_high.size[0]+2,image.size[1]),(0,0,0)) - background.paste(im_low,(0,0)) - background.paste(image,(im_low.size[0]+1,0)) - background.paste(im_high,(im_low.size[0]+1+image.size[0]+1,0)) + background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0)) + background.paste(im_low, (0, 0)) + background.paste(image, (im_low.size[0]+1, 0)) + background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0)) return background -def crop_black(img,tol=0): - mask = (img>tol).all(2) - mask0,mask1 = mask.any(0),mask.any(1) - col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax() - row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax() - return img[row_start:row_end,col_start:col_end] + +def crop_black(img, tol=0): + mask = (img > tol).all(2) + mask0, mask1 = mask.any(0), mask.any(1) + col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax() + row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax() + return img[row_start:row_end, col_start:col_end] + def extract_image_data_embed(image): - d=3 - outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) ) & 0x0F - black_cols = np.where( np.sum(outarr, axis=(0,2))==0) + d = 3 + outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F + black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0) if black_cols[0].shape[0] < 2: print('No Image data blocks found.') return None - data_block_lower = outarr[:,:black_cols[0].min(),:].astype(np.uint8) - data_block_upper = outarr[:,black_cols[0].max()+1:,:].astype(np.uint8) + data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8) + data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8) data_block_lower = xor_block(data_block_lower) data_block_upper = xor_block(data_block_upper) - + data_block = (data_block_upper << 4) | (data_block_lower) data_block = data_block.flatten().tobytes() data = zlib.decompress(data_block) - return json.loads(data,cls=EmbeddingDecoder) + return json.loads(data, cls=EmbeddingDecoder) + -def caption_image_overlay(srcimage,title,footerLeft,footerMid,footerRight,textfont=None): +def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None): from math import cos image = srcimage.copy() @@ -130,11 +142,11 @@ def caption_image_overlay(srcimage,title,footerLeft,footerMid,footerRight,textfo textfont = Roboto factor = 1.5 - gradient = Image.new('RGBA', (1,image.size[1]), color=(0,0,0,0)) + gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0)) for y in range(image.size[1]): mag = 1-cos(y/image.size[1]*factor) - mag = max(mag,1-cos((image.size[1]-y)/image.size[1]*factor*1.1)) - gradient.putpixel((0, y), (0,0,0,int(mag*255))) + mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1)) + gradient.putpixel((0, y), (0, 0, 0, int(mag*255))) image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) draw = ImageDraw.Draw(image) @@ -142,41 +154,41 @@ def caption_image_overlay(srcimage,title,footerLeft,footerMid,footerRight,textfo font = ImageFont.truetype(textfont, fontsize) padding = 10 - _,_,w, h = draw.textbbox((0,0),title,font=font) - fontsize = min( int(fontsize * (((image.size[0]*0.75)-(padding*4))/w) ), 72) + _, _, w, h = draw.textbbox((0, 0), title, font=font) + fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72) font = ImageFont.truetype(textfont, fontsize) - _,_,w,h = draw.textbbox((0,0),title,font=font) - draw.text((padding,padding), title, anchor='lt', font=font, fill=(255,255,255,230)) + _, _, w, h = draw.textbbox((0, 0), title, font=font) + draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230)) - _,_,w, h = draw.textbbox((0,0),footerLeft,font=font) - fontsize_left = min( int(fontsize * (((image.size[0]/3)-(padding))/w) ), 72) - _,_,w, h = draw.textbbox((0,0),footerMid,font=font) - fontsize_mid = min( int(fontsize * (((image.size[0]/3)-(padding))/w) ), 72) - _,_,w, h = draw.textbbox((0,0),footerRight,font=font) - fontsize_right = min( int(fontsize * (((image.size[0]/3)-(padding))/w) ), 72) + _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font) + fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72) + _, _, w, h = draw.textbbox((0, 0), footerMid, font=font) + fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72) + _, _, w, h = draw.textbbox((0, 0), footerRight, font=font) + fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72) - font = ImageFont.truetype(textfont, min(fontsize_left,fontsize_mid,fontsize_right)) + font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right)) - draw.text((padding,image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255,255,255,230)) - draw.text((image.size[0]/2,image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255,255,255,230)) - draw.text((image.size[0]-padding,image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255,255,255,230)) + draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230)) + draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230)) + draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230)) return image + if __name__ == '__main__': testEmbed = Image.open('test_embedding.png') - data = extract_image_data_embed(testEmbed) assert data is not None data = embedding_from_b64(testEmbed.text['sd-ti-embedding']) assert data is not None - - image = Image.new('RGBA',(512,512),(255,255,200,255)) + + image = Image.new('RGBA', (512, 512), (255, 255, 200, 255)) cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight') - test_embed = {'string_to_param':{'*':torch.from_numpy(np.random.random((2, 4096)))}} + test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}} embedded_image = insert_image_data_embed(cap_image, test_embed) @@ -191,16 +203,16 @@ if __name__ == '__main__': g = lcg() shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist() - reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177, - 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179, - 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193, - 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28, - 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0, - 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185, - 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82, + reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177, + 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179, + 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193, + 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28, + 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0, + 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185, + 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82, 204, 86, 73, 222, 44, 198, 118, 240, 97] - assert shared_random == reference_random + assert shared_random == reference_random hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist()) -- cgit v1.2.3 From 10a2de644f8ea4cfade88e85d768da3480f4c9f0 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 12 Oct 2022 13:15:35 +0100 Subject: formatting --- modules/textual_inversion/textual_inversion.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 485ef46c..b072d745 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,14 +7,14 @@ import tqdm import html import datetime -from PIL import Image,PngImagePlugin +from PIL import Image, PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset -from modules.textual_inversion.image_embedding import (embedding_to_b64,embedding_from_b64, - insert_image_data_embed,extract_image_data_embed, - caption_image_overlay ) +from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, + insert_image_data_embed, extract_image_data_embed, + caption_image_overlay) class Embedding: def __init__(self, vec, name, step=None): @@ -90,10 +90,10 @@ class EmbeddingDatabase: embed_image = Image.open(path) if 'sd-ti-embedding' in embed_image.text: data = embedding_from_b64(embed_image.text['sd-ti-embedding']) - name = data.get('name',name) + name = data.get('name', name) else: data = extract_image_data_embed(embed_image) - name = data.get('name',name) + name = data.get('name', name) else: data = torch.load(path, map_location="cpu") @@ -278,24 +278,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.current_image = image if save_image_with_stored_embedding and os.path.exists(last_saved_file): - + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png') info = PngImagePlugin.PngInfo() data = torch.load(last_saved_file) info.add_text("sd-ti-embedding", embedding_to_b64(data)) - title = "<{}>".format(data.get('name','???')) + title = "<{}>".format(data.get('name', '???')) checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name footer_mid = '[{}]'.format(checkpoint.hash) footer_right = '{}'.format(embedding.step) - captioned_image = caption_image_overlay(image,title,footer_left,footer_mid,footer_right) - captioned_image = insert_image_data_embed(captioned_image,data) + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) - + image.save(last_saved_image) last_saved_image += f", prompt: {preview_text}" -- cgit v1.2.3 From e05573e1adc1cde1e3bd7eb651a1ab27c446b3d5 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Wed, 12 Oct 2022 20:47:55 +0800 Subject: images history improvement --- modules/images_history.py | 67 ++++++++++++++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 24 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 23f55b30..77f692fe 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -1,15 +1,29 @@ import os -def get_recent_images(dir_name, page_index, step, image_index): - #print(image_index) +import shutil +def get_recent_images(dir_name, page_index, step, image_index, tabname): + print(f"renew page {page_index}") page_index = int(page_index) f_list = os.listdir(dir_name) file_list = [] for file in f_list: if file[-4:] == ".txt": continue - file_list.append(file) + #subdirectories + if file[-10:].rfind(".") < 0: + sub_dir = os.path.join(dir_name, file) + if os.path.isfile(sub_dir): + continue + sub_file_list = os.listdir(sub_dir) + for sub_file in sub_file_list: + if sub_file[-4:] == ".txt": + continue + if os.path.isfile(os.path.join(sub_dir, sub_file) ): + file_list.append(os.path.join(file, sub_file)) + continue + file_list.append(file) + file_list = sorted(file_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) - num = 48 + num = 48 if tabname != "extras" else 12 max_page_index = len(file_list) // num + 1 page_index = max_page_index if page_index == -1 else page_index + step page_index = 1 if page_index < 1 else page_index @@ -26,26 +40,28 @@ def get_recent_images(dir_name, page_index, step, image_index): hide_image = os.path.join(dir_name, current_file) return [os.path.join(dir_name, file) for file in file_list], page_index, file_list, current_file, hide_image def first_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, 1, 0, image_index) + return get_recent_images(dir_name, 1, 0, image_index, tabname) def end_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, -1, 0, image_index) + return get_recent_images(dir_name, -1, 0, image_index, tabname) def prev_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, page_index, -1, image_index) + return get_recent_images(dir_name, page_index, -1, image_index, tabname) def next_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, page_index, 1, image_index) + return get_recent_images(dir_name, page_index, 1, image_index, tabname) def page_index_change(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, page_index, 0, image_index) + return get_recent_images(dir_name, page_index, 0, image_index, tabname) def show_image_info(num, image_path, filenames): - #print("set img",num) + print(f"select image {num}") file = filenames[int(num)] return file, num, os.path.join(image_path, file) def delete_image(tabname, dir_name, name, page_index, filenames, image_index): - #print("filename", name) path = os.path.join(dir_name, name) - if os.path.exists(path): + if os.path.exists(path): print(f"Delete file {path}") - os.remove(path) + os.remove(path) + txt_file = os.path.splitext(path)[0] + ".txt" + if os.path.exists(txt_file): + os.remove(txt_file) new_file_list = [] for f in filenames: if f == name: @@ -64,25 +80,26 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): elif tabname == "extras": dir_name = opts.outdir_extras_samples with gr.Row(): - renew_page = gr.Button('Renew', elem_id=tabname + "_images_history_renew_page") - first_page = gr.Button('First', elem_id=tabname + "_images_history_first_page") - prev_page = gr.Button('Prev') + renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page") + first_page = gr.Button('First Page') + prev_page = gr.Button('Prev Page') page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next', elem_id=tabname + "_images_history_next_page") - end_page = gr.Button('End') + next_page = gr.Button('Next Page') + end_page = gr.Button('End Page') with gr.Row(elem_id=tabname + "_images_history"): with gr.Row(): - with gr.Column(): - history_gallery = gr.Gallery(show_label=False).style(grid=6) + with gr.Column(scale=2): + history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) + delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") with gr.Column(): with gr.Row(): - delete = gr.Button('Delete') + #pnginfo = gr.Button('PNG info') pnginfo_send_to_txt2img = gr.Button('Send to txt2img') pnginfo_send_to_img2img = gr.Button('Send to img2img') with gr.Row(): with gr.Column(): - img_file_info = gr.Textbox(label="Generate Info") - img_file_name = gr.Textbox(label="File Name") + img_file_info = gr.Textbox(label="Generate Info", interactive=False) + img_file_name = gr.Textbox(label="File Name", interactive=False) with gr.Row(): # hiden items img_path = gr.Textbox(dir_name, visible=False) @@ -90,7 +107,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): image_index = gr.Textbox(value=-1, visible=False) set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False) filenames = gr.State() - hide_image = gr.Image(visible=False, type="pil") + hide_image = gr.Image(type="pil", visible=False) info1 = gr.Textbox(visible=False) info2 = gr.Textbox(visible=False) @@ -111,6 +128,8 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hide_image]) delete.click(delete_image,_js="images_history_delete", inputs=[tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[page_index, filenames]) hide_image.change(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) + hide_image.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) + #pnginfo.click(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') -- cgit v1.2.3 From a1a94b8b5f342f467aecc53b21b80ed0227ee76a Mon Sep 17 00:00:00 2001 From: yfszzx Date: Thu, 13 Oct 2022 00:19:34 +0800 Subject: images history improvement --- modules/images_history.py | 7 ++++--- modules/ui.py | 40 +++++++++++++++++++--------------------- 2 files changed, 23 insertions(+), 24 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 2bc4b7ee..1bca0ad9 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -61,7 +61,7 @@ def delete_image(tabname, dir_name, name, page_index, filenames, image_index): os.remove(path) txt_file = os.path.splitext(path)[0] + ".txt" if os.path.exists(txt_file): - os.remove(txt_file) + os.remove(txt_file) new_file_list = [] for f in filenames: if f == name: @@ -88,7 +88,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): end_page = gr.Button('End Page') with gr.Row(elem_id=tabname + "_images_history"): with gr.Row(): - with gr.Column(scale=2): + with gr.Column(scale=2): history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") with gr.Column(): @@ -126,9 +126,10 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): #other funcitons set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hide_image]) + img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) delete.click(delete_image,_js="images_history_delete", inputs=[tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[page_index, filenames]) hide_image.change(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) - hide_image.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) + #pnginfo.click(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') diff --git a/modules/ui.py b/modules/ui.py index 94297ba6..8cd12b51 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -39,7 +39,7 @@ import modules.generation_parameters_copypaste from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui -import modules.hypernetwork.ui +import modules.hypernetworks.ui import modules.images_history as img_his # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI @@ -554,6 +554,7 @@ def create_ui(wrap_gradio_gpu_call): custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) with gr.Column(variant='panel'): + with gr.Group(): txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4) @@ -573,9 +574,9 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) - with gr.Group(): - html_info = gr.HTML() - generation_info = gr.Textbox(visible=False) + with gr.Group(): + html_info = gr.HTML() + generation_info = gr.Textbox(visible=False) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -669,7 +670,6 @@ def create_ui(wrap_gradio_gpu_call): ] modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) - with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) @@ -762,10 +762,10 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) - with gr.Group(): - html_info = gr.HTML() - generation_info = gr.Textbox(visible=False) - + with gr.Group(): + html_info = gr.HTML() + generation_info = gr.Textbox(visible=False) + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -1016,6 +1016,13 @@ def create_ui(wrap_gradio_gpu_call): inputs=[image], outputs=[html, generation_info, html2], ) + #images history + images_history_switch_dict = { + "fn":modules.generation_parameters_copypaste.connect_paste, + "t2i":txt2img_paste_fields, + "i2i":img2img_paste_fields + } + images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): @@ -1285,16 +1292,7 @@ Requested path was: {f} opts.save(shared.config_filename) - return f'{changed} settings changed.', opts.dumpjson() - - #images history - images_history_switch_dict = { - "fn":modules.generation_parameters_copypaste.connect_paste, - "t2i":txt2img_paste_fields, - "i2i":img2img_paste_fields - } - images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) - + return f'{changed} settings changed.', opts.dumpjson() def run_settings_single(value, key): if not opts.same_type(value, opts.data_labels[key].default): @@ -1393,11 +1391,10 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (images_history, "History", "images_history"), + (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), (settings_interface, "Settings", "settings"), - ] with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file: @@ -1616,3 +1613,4 @@ if 'gradio_routes_templates_response' not in globals(): gradio_routes_templates_response = gradio.routes.templates.TemplateResponse gradio.routes.templates.TemplateResponse = template_response + -- cgit v1.2.3 From a2aa2a68bc7868320b502a78765be597e507ce45 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Thu, 13 Oct 2022 00:21:16 +0800 Subject: images history improvement --- modules/images_history.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 1bca0ad9..6408973c 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -1,7 +1,7 @@ import os import shutil def get_recent_images(dir_name, page_index, step, image_index, tabname): - print(f"renew page {page_index}") + #print(f"renew page {page_index}") page_index = int(page_index) f_list = os.listdir(dir_name) file_list = [] @@ -51,7 +51,7 @@ def page_index_change(dir_name, page_index, image_index, tabname): return get_recent_images(dir_name, page_index, 0, image_index, tabname) def show_image_info(num, image_path, filenames): - print(f"select image {num}") + #print(f"select image {num}") file = filenames[int(num)] return file, num, os.path.join(image_path, file) def delete_image(tabname, dir_name, name, page_index, filenames, image_index): -- cgit v1.2.3 From c3c8eef9fd5a0c8b26319e32ca4a19b56204e6df Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 20:49:47 +0300 Subject: train: change filename processing to be more simple and configurable train: make it possible to make text files with prompts train: rework scheduler so that there's less repeating code in textual inversion and hypernets train: move epochs setting to options --- modules/hypernetworks/hypernetwork.py | 40 +++++++++------------- modules/shared.py | 3 ++ modules/textual_inversion/dataset.py | 47 +++++++++++++++++++------- modules/textual_inversion/learn_schedule.py | 37 +++++++++++++++++++- modules/textual_inversion/textual_inversion.py | 35 +++++++------------ modules/ui.py | 2 -- 6 files changed, 102 insertions(+), 62 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8314450a..b6c06d49 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -14,7 +14,7 @@ import torch from torch import einsum from einops import rearrange, repeat import modules.textual_inversion.dataset -from modules.textual_inversion.learn_schedule import LearnSchedule +from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): @@ -223,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if ititial_step > steps: return hypernetwork, filename - schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) - (learn_rate, end_step) = next(schedules) - print(f'Training at rate of {learn_rate} until step {end_step}') - - optimizer = torch.optim.AdamW(weights, lr=learn_rate) + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, (x, text, cond) in pbar: + for i, entry in pbar: hypernetwork.step = i + ititial_step - if hypernetwork.step > end_step: - try: - (learn_rate, end_step) = next(schedules) - except Exception: - break - tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}') - for pg in optimizer.param_groups: - pg['lr'] = learn_rate + scheduler.apply(optimizer, hypernetwork.step) + if scheduler.finished: + break if shared.state.interrupted: break with torch.autocast("cuda"): - cond = cond.to(devices.device) - x = x.to(devices.device) + cond = entry.cond.to(devices.device) + x = entry.latent.to(devices.device) loss = shared.sd_model(x.unsqueeze(0), cond)[0] del x del cond @@ -267,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') - preview_text = text if preview_image_prompt == "" else preview_image_prompt + preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt optimizer.zero_grad() shared.sd_model.cond_stage_model.to(devices.device) @@ -282,16 +274,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ) processed = processing.process_images(p) - image = processed.images[0] + image = processed.images[0] if len(processed.images)>0 else None if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - shared.state.current_image = image - image.save(last_saved_image) - - last_saved_image += f", prompt: {preview_text}" + if image is not None: + shared.state.current_image = image + image.save(last_saved_image) + last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step @@ -299,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,

Loss: {losses.mean():.7f}
Step: {hypernetwork.step}
-Last prompt: {html.escape(text)}
+Last prompt: {html.escape(entry.cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

diff --git a/modules/shared.py b/modules/shared.py index 42e99741..e64e69fc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -231,6 +231,9 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), + "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), + "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), + "training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index f61f40d3..67e90afe 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -11,11 +11,21 @@ import tqdm from modules import devices, shared import re -re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") +re_numbers_at_start = re.compile(r"^[-\d]+\s*") + + +class DatasetEntry: + def __init__(self, filename=None, latent=None, filename_text=None): + self.filename = filename + self.latent = latent + self.filename_text = filename_text + self.cond = None + self.cond_text = None class PersonalizedBase(Dataset): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): + re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None self.placeholder_token = placeholder_token @@ -42,9 +52,18 @@ class PersonalizedBase(Dataset): except Exception: continue + text_filename = os.path.splitext(path)[0] + ".txt" filename = os.path.basename(path) - filename_tokens = os.path.splitext(filename)[0] - filename_tokens = re_tag.findall(filename_tokens) + + if os.path.exists(text_filename): + with open(text_filename, "r", encoding="utf8") as file: + filename_text = file.read() + else: + filename_text = os.path.splitext(filename)[0] + filename_text = re.sub(re_numbers_at_start, '', filename_text) + if re_word: + tokens = re_word.findall(filename_text) + filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) npimage = np.array(image).astype(np.uint8) npimage = (npimage / 127.5 - 1.0).astype(np.float32) @@ -55,13 +74,13 @@ class PersonalizedBase(Dataset): init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() init_latent = init_latent.to(devices.cpu) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) + if include_cond: - text = self.create_text(filename_tokens) - cond = cond_model([text]).to(devices.cpu) - else: - cond = None + entry.cond_text = self.create_text(filename_text) + entry.cond = cond_model([entry.cond_text]).to(devices.cpu) - self.dataset.append((init_latent, filename_tokens, cond)) + self.dataset.append(entry) self.length = len(self.dataset) * repeats @@ -72,10 +91,10 @@ class PersonalizedBase(Dataset): def shuffle(self): self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] - def create_text(self, filename_tokens): + def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) - text = text.replace("[filewords]", ' '.join(filename_tokens)) + text = text.replace("[filewords]", filename_text) return text def __len__(self): @@ -86,7 +105,9 @@ class PersonalizedBase(Dataset): self.shuffle() index = self.indexes[i % len(self.indexes)] - x, filename_tokens, cond = self.dataset[index] + entry = self.dataset[index] + + if entry.cond is None: + entry.cond_text = self.create_text(entry.filename_text) - text = self.create_text(filename_tokens) - return x, text, cond + return entry diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index db720271..2062726a 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -1,6 +1,12 @@ +import tqdm -class LearnSchedule: + +class LearnScheduleIterator: def __init__(self, learn_rate, max_steps, cur_step=0): + """ + specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000 + """ + pairs = learn_rate.split(',') self.rates = [] self.it = 0 @@ -32,3 +38,32 @@ class LearnSchedule: return self.rates[self.it - 1] else: raise StopIteration + + +class LearnRateScheduler: + def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True): + self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step) + (self.learn_rate, self.end_step) = next(self.schedules) + self.verbose = verbose + + if self.verbose: + print(f'Training at rate of {self.learn_rate} until step {self.end_step}') + + self.finished = False + + def apply(self, optimizer, step_number): + if step_number <= self.end_step: + return + + try: + (self.learn_rate, self.end_step) = next(self.schedules) + except Exception: + self.finished = True + return + + if self.verbose: + tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}') + + for pg in optimizer.param_groups: + pg['lr'] = self.learn_rate + diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c5153e4a..fa0e33a2 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -11,7 +11,7 @@ from PIL import Image, PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset -from modules.textual_inversion.learn_schedule import LearnSchedule +from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, @@ -172,8 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn - -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -205,7 +204,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -221,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if ititial_step > steps: return embedding, filename - schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) - (learn_rate, end_step) = next(schedules) - print(f'Training at rate of {learn_rate} until step {end_step}') - - optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, (x, text, _) in pbar: + for i, entry in pbar: embedding.step = i + ititial_step - if embedding.step > end_step: - try: - (learn_rate, end_step) = next(schedules) - except: - break - tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}') - for pg in optimizer.param_groups: - pg['lr'] = learn_rate + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break if shared.state.interrupted: break with torch.autocast("cuda"): - c = cond_model([text]) + c = cond_model([entry.cond_text]) - x = x.to(devices.device) + x = entry.latent.to(devices.device) loss = shared.sd_model(x.unsqueeze(0), c)[0] del x @@ -268,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') - preview_text = text if preview_image_prompt == "" else preview_image_prompt + preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, @@ -314,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini

Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(text)}
+Last prompt: {html.escape(entry.cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

diff --git a/modules/ui.py b/modules/ui.py index 2b332267..c42535c8 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1098,7 +1098,6 @@ def create_ui(wrap_gradio_gpu_call): training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) - num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) @@ -1176,7 +1175,6 @@ def create_ui(wrap_gradio_gpu_call): training_width, training_height, steps, - num_repeats, create_image_every, save_embedding_every, template_file, -- cgit v1.2.3 From 698d303b04e293635bfb49c525409f3bcf671dce Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 21:55:43 +0300 Subject: deepbooru: added option to use spaces or underscores deepbooru: added option to quote (\) in tags deepbooru/BLIP: write caption to file instead of image filename deepbooru/BLIP: now possible to use both for captions deepbooru: process is stopped even if an exception occurs --- modules/deepbooru.py | 65 ++++++++++++++++++----- modules/shared.py | 2 + modules/textual_inversion/preprocess.py | 92 ++++++++++++++------------------- modules/ui.py | 7 +-- 4 files changed, 95 insertions(+), 71 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 29529949..419e6a9c 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -2,33 +2,44 @@ import os.path from concurrent.futures import ProcessPoolExecutor import multiprocessing import time +import re + +re_special = re.compile(r'([\\()])') def get_deepbooru_tags(pil_image): """ This method is for running only one image at a time for simple use. Used to the img2img interrogate. """ from modules import shared # prevents circular reference - create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, shared.opts.deepbooru_sort_alpha) - shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process_queue.put(pil_image) - while shared.deepbooru_process_return["value"] == -1: - time.sleep(0.2) - tags = shared.deepbooru_process_return["value"] - release_process() - return tags + try: + create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts()) + return get_tags_from_process(pil_image) + finally: + release_process() + + +def create_deepbooru_opts(): + from modules import shared -def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort): + return { + "use_spaces": shared.opts.deepbooru_use_spaces, + "use_escape": shared.opts.deepbooru_escape, + "alpha_sort": shared.opts.deepbooru_sort_alpha, + } + + +def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts): model, tags = get_deepbooru_tags_model() while True: # while process is running, keep monitoring queue for new image pil_image = queue.get() if pil_image == "QUIT": break else: - deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort) + deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts) -def create_deepbooru_process(threshold, alpha_sort): +def create_deepbooru_process(threshold, deepbooru_opts): """ Creates deepbooru process. A queue is created to send images into the process. This enables multiple images to be processed in a row without reloading the model or creating a new process. To return the data, a shared @@ -41,10 +52,23 @@ def create_deepbooru_process(threshold, alpha_sort): shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort)) + shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts)) shared.deepbooru_process.start() +def get_tags_from_process(image): + from modules import shared + + shared.deepbooru_process_return["value"] = -1 + shared.deepbooru_process_queue.put(image) + while shared.deepbooru_process_return["value"] == -1: + time.sleep(0.2) + caption = shared.deepbooru_process_return["value"] + shared.deepbooru_process_return["value"] = -1 + + return caption + + def release_process(): """ Stops the deepbooru process to return used memory @@ -81,10 +105,15 @@ def get_deepbooru_tags_model(): return model, tags -def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort): +def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts): import deepdanbooru as dd import tensorflow as tf import numpy as np + + alpha_sort = deepbooru_opts['alpha_sort'] + use_spaces = deepbooru_opts['use_spaces'] + use_escape = deepbooru_opts['use_escape'] + width = model.input_shape[2] height = model.input_shape[1] image = np.array(pil_image) @@ -129,4 +158,12 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort) print('\n'.join(sorted(result_tags_print, reverse=True))) - return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') + tags_text = ', '.join(result_tags_out) + + if use_spaces: + tags_text = tags_text.replace('_', ' ') + + if use_escape: + tags_text = re.sub(re_special, r'\\\1', tags_text) + + return tags_text.replace(':', ' ') diff --git a/modules/shared.py b/modules/shared.py index e64e69fc..78b73aae 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -260,6 +260,8 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), + "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), + "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"), })) options_templates.update(options_section(('ui', "User interface"), { diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 113cecf1..3047bede 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -10,7 +10,28 @@ from modules.shared import opts, cmd_opts if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru + def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): + try: + if process_caption: + shared.interrogator.load() + + if process_caption_deepbooru: + deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, deepbooru.create_deepbooru_opts()) + + preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + + finally: + + if process_caption: + shared.interrogator.send_blip_to_ram() + + if process_caption_deepbooru: + deepbooru.release_process() + + + +def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): width = process_width height = process_height src = os.path.abspath(process_src) @@ -25,30 +46,28 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ shared.state.textinfo = "Preprocessing..." shared.state.job_count = len(files) - if process_caption: - shared.interrogator.load() - - if process_caption_deepbooru: - deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, opts.deepbooru_sort_alpha) - def save_pic_with_caption(image, index): + caption = "" + if process_caption: - caption = "-" + shared.interrogator.generate_caption(image) - caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png") - elif process_caption_deepbooru: - shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process_queue.put(image) - while shared.deepbooru_process_return["value"] == -1: - time.sleep(0.2) - caption = "-" + shared.deepbooru_process_return["value"] - caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png") - shared.deepbooru_process_return["value"] = -1 - else: - caption = filename - caption = os.path.splitext(caption)[0] - caption = os.path.basename(caption) + caption += shared.interrogator.generate_caption(image) + + if process_caption_deepbooru: + if len(caption) > 0: + caption += ", " + caption += deepbooru.get_tags_from_process(image) + + filename_part = filename + filename_part = os.path.splitext(filename_part)[0] + filename_part = os.path.basename(filename_part) + + basename = f"{index:05}-{subindex[0]}-{filename_part}" + image.save(os.path.join(dst, f"{basename}.png")) + + if len(caption) > 0: + with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file: + file.write(caption) - image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png")) subindex[0] += 1 def save_pic(image, index): @@ -93,34 +112,3 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ save_pic(img, index) shared.state.nextjob() - - if process_caption: - shared.interrogator.send_blip_to_ram() - - if process_caption_deepbooru: - deepbooru.release_process() - - -def sanitize_caption(base_path, original_caption, suffix): - operating_system = platform.system().lower() - if (operating_system == "windows"): - invalid_path_characters = "\\/:*?\"<>|" - max_path_length = 259 - else: - invalid_path_characters = "/" #linux/macos - max_path_length = 1023 - caption = original_caption - for invalid_character in invalid_path_characters: - caption = caption.replace(invalid_character, "") - fixed_path_length = len(base_path) + len(suffix) - if fixed_path_length + len(caption) <= max_path_length: - return caption - caption_tokens = caption.split() - new_caption = "" - for token in caption_tokens: - last_caption = new_caption - new_caption = new_caption + token + " " - if (len(new_caption) + fixed_path_length - 1 > max_path_length): - break - print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr) - return last_caption.strip() diff --git a/modules/ui.py b/modules/ui.py index c42535c8..e07ee0e1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1074,11 +1074,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') process_split = gr.Checkbox(label='Split oversized images into two') - process_caption = gr.Checkbox(label='Use BLIP caption as filename') - if cmd_opts.deepdanbooru: - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename') - else: - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename', visible=False) + process_caption = gr.Checkbox(label='Use BLIP for caption') + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) with gr.Row(): with gr.Column(scale=3): -- cgit v1.2.3 From efefa4862c6c75115d3da9f768348630cc32bdea Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Wed, 12 Oct 2022 13:03:00 -0700 Subject: [1/?] [wip] Reintroduce opts.interrogate_return_ranks looks functionally correct, needs testing Needs particular testing care around whether the colon usage (:) will break anything in whatever new use cases were introduced by https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2143 --- modules/deepbooru.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 419e6a9c..2cbf2cab 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -26,6 +26,7 @@ def create_deepbooru_opts(): "use_spaces": shared.opts.deepbooru_use_spaces, "use_escape": shared.opts.deepbooru_escape, "alpha_sort": shared.opts.deepbooru_sort_alpha, + "include_ranks": shared.opts.interrogate_return_ranks, } @@ -113,6 +114,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o alpha_sort = deepbooru_opts['alpha_sort'] use_spaces = deepbooru_opts['use_spaces'] use_escape = deepbooru_opts['use_escape'] + include_ranks = deepbooru_opts['include_ranks'] width = model.input_shape[2] height = model.input_shape[1] @@ -151,19 +153,20 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o if alpha_sort: sort_ndx = 1 - # sort by reverse by likelihood and normal for alpha + # sort by reverse by likelihood and normal for alpha, and format tag text as requested unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) for weight, tag in unsorted_tags_in_theshold: - result_tags_out.append(tag) + # note: tag_outformat will still have a colon if include_ranks is True + tag_outformat = tag.replace(':', ' ') + if use_spaces: + tag_outformat = tag_outformat.replace('_', ' ') + if use_escape: + tag_outformat = re.sub(re_special, r'\\\1', tag_outformat) + if include_ranks: + use_escape += f":{weight:.3f}" - print('\n'.join(sorted(result_tags_print, reverse=True))) - - tags_text = ', '.join(result_tags_out) + result_tags_out.append(tag_outformat) - if use_spaces: - tags_text = tags_text.replace('_', ' ') - - if use_escape: - tags_text = re.sub(re_special, r'\\\1', tags_text) + print('\n'.join(sorted(result_tags_print, reverse=True))) - return tags_text.replace(':', ' ') + return ', '.join(result_tags_out) -- cgit v1.2.3 From f776254b12361b5bae16f6629bcdcb47b450c48d Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Wed, 12 Oct 2022 13:08:06 -0700 Subject: [2/?] [wip] ignore OPT_INCLUDE_RANKS for training filenames --- modules/deepbooru.py | 3 ++- modules/textual_inversion/preprocess.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 2cbf2cab..fcc05819 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -19,6 +19,7 @@ def get_deepbooru_tags(pil_image): release_process() +OPT_INCLUDE_RANKS = "include_ranks" def create_deepbooru_opts(): from modules import shared @@ -26,7 +27,7 @@ def create_deepbooru_opts(): "use_spaces": shared.opts.deepbooru_use_spaces, "use_escape": shared.opts.deepbooru_escape, "alpha_sort": shared.opts.deepbooru_sort_alpha, - "include_ranks": shared.opts.interrogate_return_ranks, + OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks, } diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 3047bede..886cf0c3 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -17,7 +17,9 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ shared.interrogator.load() if process_caption_deepbooru: - deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, deepbooru.create_deepbooru_opts()) + db_opts = deepbooru.create_deepbooru_opts() + db_opts[deepbooru.OPT_INCLUDE_RANKS] = False + deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) -- cgit v1.2.3 From 514456101b142b47acf87f6de95bad1a23d73be7 Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Wed, 12 Oct 2022 13:14:13 -0700 Subject: [3/?] [wip] fix incorrect variable reference still needs testing --- modules/deepbooru.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index fcc05819..c2004696 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -164,7 +164,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o if use_escape: tag_outformat = re.sub(re_special, r'\\\1', tag_outformat) if include_ranks: - use_escape += f":{weight:.3f}" + tag_outformat += f":{weight:.3f}" result_tags_out.append(tag_outformat) -- cgit v1.2.3 From 1cfc2a18981ee56bdb69a2de7b463a11ad05e329 Mon Sep 17 00:00:00 2001 From: Melan Date: Wed, 12 Oct 2022 23:36:29 +0200 Subject: Save a csv containing the loss while training --- modules/hypernetworks/hypernetwork.py | 17 ++++++++++++++++- modules/textual_inversion/textual_inversion.py | 17 ++++++++++++++++- modules/ui.py | 3 +++ 3 files changed, 35 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b6c06d49..6522078f 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -5,6 +5,7 @@ import os import sys import traceback import tqdm +import csv import torch @@ -174,7 +175,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): return self.to_out(out) -def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): +def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, write_csv_every, template_file, preview_image_prompt): assert hypernetwork_name, 'hypernetwork not selected' path = shared.hypernetworks.get(hypernetwork_name, None) @@ -256,6 +257,20 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') hypernetwork.save(last_saved_file) + print(f"{write_csv_every} > {hypernetwork.step % write_csv_every == 0}, {write_csv_every}") + if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0: + write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True + + with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout: + + csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss"]) + + if write_csv_header: + csv_writer.writeheader() + + csv_writer.writerow({"step": hypernetwork.step, + "loss": f"{losses.mean():.7f}"}) + if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index fa0e33a2..25038a89 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -6,6 +6,7 @@ import torch import tqdm import html import datetime +import csv from PIL import Image, PngImagePlugin @@ -172,7 +173,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, write_csv_every, template_file, save_image_with_stored_embedding, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -256,6 +257,20 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') embedding.save(last_saved_file) + if write_csv_every > 0 and log_directory is not None and embedding.step % write_csv_every == 0: + write_csv_header = False if os.path.exists(os.path.join(log_directory, "textual_inversion_loss.csv")) else True + + with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout: + + csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss"]) + + if write_csv_header: + csv_writer.writeheader() + + csv_writer.writerow({"epoch": epoch_num + 1, + "epoch_step": epoch_step - 1, + "loss": f"{losses.mean():.7f}"}) + if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') diff --git a/modules/ui.py b/modules/ui.py index e07ee0e1..1195c2f1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1096,6 +1096,7 @@ def create_ui(wrap_gradio_gpu_call): training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) + write_csv_every = gr.Number(label='Save an csv containing the loss to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) preview_image_prompt = gr.Textbox(label='Preview prompt', value="") @@ -1174,6 +1175,7 @@ def create_ui(wrap_gradio_gpu_call): steps, create_image_every, save_embedding_every, + write_csv_every, template_file, save_image_with_stored_embedding, preview_image_prompt, @@ -1195,6 +1197,7 @@ def create_ui(wrap_gradio_gpu_call): steps, create_image_every, save_embedding_every, + write_csv_every, template_file, preview_image_prompt, ], -- cgit v1.2.3 From 54e0051bdd7dea7348825c09600ec61ea0771cb8 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Wed, 12 Oct 2022 18:17:26 -0500 Subject: Add drag/drop param loading. Drop an image or generational text onto the prompt bar, it loads the info for parsing. --- modules/images.py | 20 ++++++++++++++++++++ modules/ui.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index c0a90676..f1155b7f 100644 --- a/modules/images.py +++ b/modules/images.py @@ -463,3 +463,23 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i txt_fullfn = None return fullfn, txt_fullfn + + +def image_data(image_path): + file, ext = os.path.splitext(image_path.name) + data = {} + if "png" in ext: + image = Image.open(image_path.name, "r") + print(f"Image data requested for {image_path.name} {image.format} of {type(image)}") + try: + data = image.text["parameters"] + except Exception as e: + print(f"Exception: {e}") + pass + print(f"Image data: {data}") + if "txt" in ext: + myfile = open(image_path.name, 'r') + data = myfile.read() + myfile.close() + + return data, None diff --git a/modules/ui.py b/modules/ui.py index 2b332267..dd793c39 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -431,7 +431,6 @@ def create_toprow(is_img2img): with gr.Column(scale=80): with gr.Row(): prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2) - with gr.Column(scale=1, elem_id="roll_col"): roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) paste = gr.Button(value=paste_symbol, elem_id="paste") @@ -513,6 +512,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="file", visible=False) with gr.Row(elem_id='txt2img_progress_row'): with gr.Column(scale=1): @@ -614,6 +614,18 @@ def create_ui(wrap_gradio_gpu_call): txt2img_prompt.submit(**txt2img_args) submit.click(**txt2img_args) + txt_prompt_img.change( + fn=modules.images.image_data, + # _js = "get_extras_tab_index", + inputs=[ + txt_prompt_img + ], + outputs=[ + txt2img_prompt, + txt_prompt_img + ] + ) + enable_hr.change( fn=lambda x: gr_show(x), inputs=[enable_hr], @@ -674,6 +686,9 @@ def create_ui(wrap_gradio_gpu_call): img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): + img2img_prompt_img = gr.File(label="", elem_id="txt_prompt_image", file_count="single", type="file", + visible=False) + with gr.Column(scale=1): pass @@ -768,6 +783,18 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + img2img_prompt_img.change( + fn=modules.images.image_data, + # _js = "get_extras_tab_index", + inputs=[ + txt_prompt_img + ], + outputs=[ + img2img_prompt, + img2img_prompt_img + ] + ) + mask_mode.change( lambda mode, img: { init_img_with_mask: gr_show(mode == 0), @@ -956,6 +983,7 @@ def create_ui(wrap_gradio_gpu_call): button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else '' open_extras_folder = gr.Button('Open output directory', elem_id=button_id) + submit.click( fn=wrap_gradio_gpu_call(modules.extras.run_extras), _js="get_extras_tab_index", -- cgit v1.2.3 From 716a9e034f1aff434083363b218bd6043a774fc2 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Thu, 13 Oct 2022 12:19:50 +0800 Subject: images history delete a number of images consecutively next --- modules/images_history.py | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 6408973c..f812ea4e 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -54,23 +54,26 @@ def show_image_info(num, image_path, filenames): #print(f"select image {num}") file = filenames[int(num)] return file, num, os.path.join(image_path, file) -def delete_image(tabname, dir_name, name, page_index, filenames, image_index): - path = os.path.join(dir_name, name) - if os.path.exists(path): - print(f"Delete file {path}") - os.remove(path) - txt_file = os.path.splitext(path)[0] + ".txt" - if os.path.exists(txt_file): - os.remove(txt_file) - new_file_list = [] - for f in filenames: - if f == name: - continue - new_file_list.append(f) - else: - print(f"Not exists file {path}") - new_file_list = filenames - return page_index, new_file_list +def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index): + delete_num = int(delete_num) + index = list(filenames).index(name) + i = 0 + new_file_list = [] + for name in filenames: + if i >= index and i < index + delete_num: + path = os.path.join(dir_name, name) + if os.path.exists(path): + print(f"Delete file {path}") + os.remove(path) + txt_file = os.path.splitext(path)[0] + ".txt" + if os.path.exists(txt_file): + os.remove(txt_file) + else: + print(f"Not exists file {path}") + else: + new_file_list.append(name) + i += 1 + return page_index, new_file_list, 1 def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): if tabname == "txt2img": @@ -90,10 +93,11 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Row(): with gr.Column(scale=2): history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) - delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") + with gr.Row(): + delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") + delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") with gr.Column(): with gr.Row(): - #pnginfo = gr.Button('PNG info') pnginfo_send_to_txt2img = gr.Button('Send to txt2img') pnginfo_send_to_img2img = gr.Button('Send to img2img') with gr.Row(): @@ -127,7 +131,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): #other funcitons set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hide_image]) img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) - delete.click(delete_image,_js="images_history_delete", inputs=[tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[page_index, filenames]) + delete.click(delete_image,_js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[page_index, filenames, delete_num]) hide_image.change(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) #pnginfo.click(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) -- cgit v1.2.3 From 78592d404acba7db3baf8d78bdc19266906e684a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 07:40:03 +0300 Subject: remove interrogate option I accidentally deleted --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 78b73aae..9bda45c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -258,6 +258,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), + "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), -- cgit v1.2.3 From 04c0e643f2eec68d93a76db171b4d70595808702 Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Wed, 12 Oct 2022 22:13:53 -0700 Subject: Merge branch 'master' of https://github.com/HunterVacui/stable-diffusion-webui --- modules/deepbooru.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index c2004696..f34f3788 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -164,7 +164,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o if use_escape: tag_outformat = re.sub(re_special, r'\\\1', tag_outformat) if include_ranks: - tag_outformat += f":{weight:.3f}" + tag_outformat = f"({tag_outformat}:{weight:.3f})" result_tags_out.append(tag_outformat) -- cgit v1.2.3 From e72adc999b3531370eafb9d316924ac497feb445 Mon Sep 17 00:00:00 2001 From: Trung Ngo Date: Sat, 8 Oct 2022 22:57:19 -0500 Subject: Restore last generation params --- modules/generation_parameters_copypaste.py | 8 ++++++++ modules/processing.py | 4 ++++ 2 files changed, 12 insertions(+) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index ac1ba7f4..3e75aecc 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,5 +1,7 @@ +import os import re import gradio as gr +from modules.shared import script_path re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" re_param = re.compile(re_param_code) @@ -61,6 +63,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model def connect_paste(button, paste_fields, input_comp, js=None): def paste_func(prompt): + if not prompt: + filename = os.path.join(script_path, "params.txt") + if os.path.exists(filename): + with open(filename, "r", encoding="utf8") as file: + prompt = file.read() + params = parse_generation_parameters(prompt) res = [] diff --git a/modules/processing.py b/modules/processing.py index 698b3069..d5172f00 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -324,6 +324,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: else: assert p.prompt is not None + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: + processed = Processed(p, [], p.seed, "") + file.write(processed.infotext(p, 0)) + devices.torch_gc() seed = get_fixed_seed(p.seed) -- cgit v1.2.3 From fde7fefa2ea23747f1107e3e46bf60c08a1134f1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 12:26:34 +0300 Subject: update #2336 to prevent reading params.txt when --hide-ui-dir-config option is enabled (for servers, since this will let some users access others' params) --- modules/generation_parameters_copypaste.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 3e75aecc..c27826b6 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -2,6 +2,7 @@ import os import re import gradio as gr from modules.shared import script_path +from modules import shared re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" re_param = re.compile(re_param_code) @@ -63,7 +64,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model def connect_paste(button, paste_fields, input_comp, js=None): def paste_func(prompt): - if not prompt: + if not prompt and not shared.cmd_opts.hide_ui_dir_config: filename = os.path.join(script_path, "params.txt") if os.path.exists(filename): with open(filename, "r", encoding="utf8") as file: -- cgit v1.2.3 From aeacbac218c47f61f1d0d3f3b429c9038b8faf0f Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Tue, 11 Oct 2022 19:46:33 -0700 Subject: Fix save error --- modules/ui.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index e07ee0e1..4fa405a9 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -148,7 +148,10 @@ def save_files(js_data, images, do_make_zip, index): is_grid = image_index < p.index_of_first_image i = 0 if is_grid else (image_index - p.index_of_first_image) - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + seed = p.all_seeds[i] if len(p.all_seeds) > 1 else p.seed + prompt = p.all_prompts[i] if len(p.all_prompts) > 1 else p.prompt + info = p.infotexts[image_index] if len(p.infotexts) > 1 else p.infotexts[0] + fullfn, txt_fullfn = save_image(image, path, "", seed=seed, prompt=prompt, extension=extension, info=info, grid=is_grid, p=p, save_to_dirs=save_to_dirs) filename = os.path.relpath(fullfn, path) filenames.append(filename) -- cgit v1.2.3 From 8711c2fe0135d5c160a57db41cb79ed1942ce7fa Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Wed, 12 Oct 2022 16:12:12 -0700 Subject: Fix metadata contents --- modules/ui.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 4fa405a9..e07ee0e1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -148,10 +148,7 @@ def save_files(js_data, images, do_make_zip, index): is_grid = image_index < p.index_of_first_image i = 0 if is_grid else (image_index - p.index_of_first_image) - seed = p.all_seeds[i] if len(p.all_seeds) > 1 else p.seed - prompt = p.all_prompts[i] if len(p.all_prompts) > 1 else p.prompt - info = p.infotexts[image_index] if len(p.infotexts) > 1 else p.infotexts[0] - fullfn, txt_fullfn = save_image(image, path, "", seed=seed, prompt=prompt, extension=extension, info=info, grid=is_grid, p=p, save_to_dirs=save_to_dirs) + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) filename = os.path.relpath(fullfn, path) filenames.append(filename) -- cgit v1.2.3 From a3f02e4690844715a510b7bc857a0971dd05c4d8 Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Wed, 12 Oct 2022 16:48:53 -0700 Subject: fix prompt in log.csv --- modules/ui.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index e07ee0e1..edb4dab1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -139,6 +139,8 @@ def save_files(js_data, images, do_make_zip, index): if at_start: writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + log_prompt=data["prompt"] + log_seed=data["seed"] for image_index, filedata in enumerate(images, start_index): if filedata.startswith("data:image/png;base64,"): filedata = filedata[len("data:image/png;base64,"):] @@ -148,7 +150,9 @@ def save_files(js_data, images, do_make_zip, index): is_grid = image_index < p.index_of_first_image i = 0 if is_grid else (image_index - p.index_of_first_image) - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + log_seed=p.all_seeds[i] + log_prompt=p.all_prompts[i] + fullfn, txt_fullfn = save_image(image, path, "", seed=log_seed, prompt=log_prompt, extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) filename = os.path.relpath(fullfn, path) filenames.append(filename) @@ -157,7 +161,7 @@ def save_files(js_data, images, do_make_zip, index): filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + writer.writerow([log_prompt, log_seed, data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) # Make Zip if do_make_zip: -- cgit v1.2.3 From fed7f0e281a42ea962bbe422e018468bafa6f1e6 Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Wed, 12 Oct 2022 23:09:30 -0700 Subject: Revert "fix prompt in log.csv" This reverts commit e4b5d1696429ab78dae9779420ce6ec4cd9c5f67. --- modules/ui.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index edb4dab1..e07ee0e1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -139,8 +139,6 @@ def save_files(js_data, images, do_make_zip, index): if at_start: writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - log_prompt=data["prompt"] - log_seed=data["seed"] for image_index, filedata in enumerate(images, start_index): if filedata.startswith("data:image/png;base64,"): filedata = filedata[len("data:image/png;base64,"):] @@ -150,9 +148,7 @@ def save_files(js_data, images, do_make_zip, index): is_grid = image_index < p.index_of_first_image i = 0 if is_grid else (image_index - p.index_of_first_image) - log_seed=p.all_seeds[i] - log_prompt=p.all_prompts[i] - fullfn, txt_fullfn = save_image(image, path, "", seed=log_seed, prompt=log_prompt, extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) filename = os.path.relpath(fullfn, path) filenames.append(filename) @@ -161,7 +157,7 @@ def save_files(js_data, images, do_make_zip, index): filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([log_prompt, log_seed, data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) # Make Zip if do_make_zip: -- cgit v1.2.3 From 8636b50aea83f9c743f005722d9f3f8ee9303e00 Mon Sep 17 00:00:00 2001 From: Melan Date: Thu, 13 Oct 2022 12:37:58 +0200 Subject: Add learn_rate to csv and removed a left-over debug statement --- modules/hypernetworks/hypernetwork.py | 6 +++--- modules/textual_inversion/textual_inversion.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 6522078f..2751a8c8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -257,19 +257,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') hypernetwork.save(last_saved_file) - print(f"{write_csv_every} > {hypernetwork.step % write_csv_every == 0}, {write_csv_every}") if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0: write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout: - csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss"]) + csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss", "learn_rate"]) if write_csv_header: csv_writer.writeheader() csv_writer.writerow({"step": hypernetwork.step, - "loss": f"{losses.mean():.7f}"}) + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate}) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 25038a89..b83df079 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -262,14 +262,15 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout: - csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss"]) + csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss", "learn_rate"]) if write_csv_header: csv_writer.writeheader() csv_writer.writerow({"epoch": epoch_num + 1, "epoch_step": epoch_step - 1, - "loss": f"{losses.mean():.7f}"}) + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate}) if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') -- cgit v1.2.3 From bb7baf6b9cb6b4b9fa09b6f07ef997db32fe6e58 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 16:07:18 +0300 Subject: add option to change what's shown in quicksettings bar --- modules/shared.py | 4 ++-- modules/ui.py | 16 +++++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 5f6101a4..4d3ed625 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -152,7 +152,6 @@ class OptionInfo: self.component_args = component_args self.onchange = onchange self.section = None - self.show_on_main_page = show_on_main_page def options_section(section_identifier, options_dict): @@ -237,7 +236,7 @@ options_templates.update(options_section(('training', "Training"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), @@ -250,6 +249,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), + 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { diff --git a/modules/ui.py b/modules/ui.py index e07ee0e1..a0529860 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1305,6 +1305,9 @@ Requested path was: {f} settings_cols = 3 items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') + quicksettings_list = [] cols_displayed = 0 @@ -1329,7 +1332,7 @@ Requested path was: {f} gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='

{}

'.format(item.section[1])) - if item.show_on_main_page: + if k in quicksettings_names: quicksettings_list.append((i, k, item)) components.append(dummy_component) else: @@ -1338,7 +1341,11 @@ Requested path was: {f} components.append(component) items_displayed += 1 - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + with gr.Row(): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') + request_notifications.click( fn=lambda: None, inputs=[], @@ -1346,10 +1353,6 @@ Requested path was: {f} _js='function(){}' ) - with gr.Row(): - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - def reload_scripts(): modules.scripts.reload_script_body_only() @@ -1364,7 +1367,6 @@ Requested path was: {f} shared.state.interrupt() settings_interface.gradio_ref.do_restart = True - restart_gradio.click( fn=request_restart, inputs=[], -- cgit v1.2.3 From a10b0e11fc22cc67b6a3664f2ddd17425d8433a8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 19:22:41 +0300 Subject: options to refresh list of models and hypernetworks --- modules/shared.py | 9 +++++---- modules/ui.py | 33 +++++++++++++++++++++++++++++---- 2 files changed, 34 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 4d3ed625..d8e3a286 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers +from modules import sd_samplers, sd_models from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -145,13 +145,14 @@ def realesrgan_models_names(): class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False): + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None): self.default = default self.label = label self.component = component self.component_args = component_args self.onchange = onchange self.section = None + self.refresh = refresh def options_section(section_identifier, options_dict): @@ -236,8 +237,8 @@ options_templates.update(options_section(('training', "Training"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), - "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), + "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), diff --git a/modules/ui.py b/modules/ui.py index a0529860..0a58f6be 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -78,6 +78,8 @@ reuse_symbol = '\u267b\ufe0f' # ♻️ art_symbol = '\U0001f3a8' # 🎨 paste_symbol = '\u2199\ufe0f' # ↙ folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 + def plaintext_to_html(text): text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" @@ -1210,8 +1212,7 @@ def create_ui(wrap_gradio_gpu_call): outputs=[], ) - - def create_setting_component(key): + def create_setting_component(key, is_quicksettings=False): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default @@ -1231,7 +1232,31 @@ def create_ui(wrap_gradio_gpu_call): else: raise Exception(f'bad options item type: {str(t)} for key {key}') - return comp(label=info.label, value=fun, **(args or {})) + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun, **(args or {})) + refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key) + else: + with gr.Row(variant="compact"): + res = comp(label=info.label, value=fun, **(args or {})) + refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key) + + def refresh(): + info.refresh() + refreshed_args = info.component_args() if callable(info.component_args) else info.component_args + res.choices = refreshed_args["choices"] + return gr.update(**(refreshed_args or {})) + + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[res], + ) + else: + res = comp(label=info.label, value=fun, **(args or {})) + + + return res components = [] component_dict = {} @@ -1401,7 +1426,7 @@ Requested path was: {f} with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Row(elem_id="quicksettings"): for i, k, item in quicksettings_list: - component = create_setting_component(k) + component = create_setting_component(k, is_quicksettings=True) component_dict[k] = component settings_interface.gradio_ref = demo -- cgit v1.2.3 From 354ef0da3b1f0fa5c113d04b6c79e3908c848d23 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 20:12:37 +0300 Subject: add hypernetwork multipliers --- modules/hypernetworks/hypernetwork.py | 8 +++++++- modules/shared.py | 5 ++++- modules/ui.py | 5 ++++- 3 files changed, 15 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b6c06d49..f1248bb7 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): + multiplier = 1.0 + def __init__(self, dim, state_dict=None): super().__init__() @@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module): self.to(devices.device) def forward(self, x): - return x + (self.linear2(self.linear1(x))) + return x + (self.linear2(self.linear1(x))) * self.multiplier + + +def apply_strength(value=None): + HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength class Hypernetwork: diff --git a/modules/shared.py b/modules/shared.py index d8e3a286..5901e605 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -238,7 +238,8 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), - "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), + "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), + "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), @@ -348,6 +349,8 @@ class Options: item = self.data_labels.get(key) item.onchange = func + func() + def dumpjson(self): d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} return json.dumps(d) diff --git a/modules/ui.py b/modules/ui.py index 0a58f6be..673014f2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1244,7 +1244,10 @@ def create_ui(wrap_gradio_gpu_call): def refresh(): info.refresh() refreshed_args = info.component_args() if callable(info.component_args) else info.component_args - res.choices = refreshed_args["choices"] + + for k, v in refreshed_args.items(): + setattr(res, k, v) + return gr.update(**(refreshed_args or {})) refresh_button.click( -- cgit v1.2.3 From 08b3f7aef15f74f4d2254b1274dd66fcc7940348 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 20:42:27 +0300 Subject: emergency fix for broken send to buttons --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 673014f2..7446439d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1434,7 +1434,7 @@ Requested path was: {f} settings_interface.gradio_ref = demo - with gr.Tabs() as tabs: + with gr.Tabs(elem_id="tabs") as tabs: for interface, label, ifid in interfaces: with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): interface.render() -- cgit v1.2.3 From a1489f94283c07824a7a58353c03dc89541bbe49 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Fri, 14 Oct 2022 07:13:38 +0800 Subject: images history fix all known bug --- modules/images_history.py | 51 +++++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 24 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index f812ea4e..cdfcffed 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -38,7 +38,7 @@ def get_recent_images(dir_name, page_index, step, image_index, tabname): else: current_file = file_list[int(image_index)] hide_image = os.path.join(dir_name, current_file) - return [os.path.join(dir_name, file) for file in file_list], page_index, file_list, current_file, hide_image + return [os.path.join(dir_name, file) for file in file_list], page_index, file_list, current_file, hide_image, "" def first_page_click(dir_name, page_index, image_index, tabname): return get_recent_images(dir_name, 1, 0, image_index, tabname) def end_page_click(dir_name, page_index, image_index, tabname): @@ -55,25 +55,28 @@ def show_image_info(num, image_path, filenames): file = filenames[int(num)] return file, num, os.path.join(image_path, file) def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index): - delete_num = int(delete_num) - index = list(filenames).index(name) - i = 0 - new_file_list = [] - for name in filenames: - if i >= index and i < index + delete_num: - path = os.path.join(dir_name, name) - if os.path.exists(path): - print(f"Delete file {path}") - os.remove(path) - txt_file = os.path.splitext(path)[0] + ".txt" - if os.path.exists(txt_file): - os.remove(txt_file) + if name == "": + return filenames, delete_num + else: + delete_num = int(delete_num) + index = list(filenames).index(name) + i = 0 + new_file_list = [] + for name in filenames: + if i >= index and i < index + delete_num: + path = os.path.join(dir_name, name) + if os.path.exists(path): + print(f"Delete file {path}") + os.remove(path) + txt_file = os.path.splitext(path)[0] + ".txt" + if os.path.exists(txt_file): + os.remove(txt_file) + else: + print(f"Not exists file {path}") else: - print(f"Not exists file {path}") - else: - new_file_list.append(name) - i += 1 - return page_index, new_file_list, 1 + new_file_list.append(name) + i += 1 + return new_file_list, 1 def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): if tabname == "txt2img": @@ -93,9 +96,9 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Row(): with gr.Column(scale=2): history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) - with gr.Row(): - delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") + with gr.Row(): + delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") + delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") with gr.Column(): with gr.Row(): pnginfo_send_to_txt2img = gr.Button('Send to txt2img') @@ -118,7 +121,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): # turn pages gallery_inputs = [img_path, page_index, image_index, tabname_box] - gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hide_image] + gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hide_image, img_file_name] first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) @@ -131,7 +134,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): #other funcitons set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hide_image]) img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) - delete.click(delete_image,_js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[page_index, filenames, delete_num]) + delete.click(delete_image,_js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num]) hide_image.change(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) #pnginfo.click(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) -- cgit v1.2.3 From 4a37c7eedeab579efec03e8dae3f3f9fd4a37b02 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Fri, 14 Oct 2022 11:48:28 +0800 Subject: fix deep nesting directories problem --- modules/images_history.py | 76 ++++++++++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 34 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index cdfcffed..723f5301 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -1,44 +1,47 @@ import os import shutil -def get_recent_images(dir_name, page_index, step, image_index, tabname): - #print(f"renew page {page_index}") - page_index = int(page_index) - f_list = os.listdir(dir_name) - file_list = [] +def traverse_all_files(output_dir, image_list, curr_dir=None): + curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir) + try: + f_list = os.listdir(curr_path) + except: + if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt": + image_list.append(curr_dir) + return image_list for file in f_list: + file = file if curr_dir is None else os.path.join(curr_dir, file) + file_path = os.path.join(curr_path, file) if file[-4:] == ".txt": - continue - #subdirectories - if file[-10:].rfind(".") < 0: - sub_dir = os.path.join(dir_name, file) - if os.path.isfile(sub_dir): - continue - sub_file_list = os.listdir(sub_dir) - for sub_file in sub_file_list: - if sub_file[-4:] == ".txt": - continue - if os.path.isfile(os.path.join(sub_dir, sub_file) ): - file_list.append(os.path.join(file, sub_file)) - continue - file_list.append(file) + pass + elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0: + image_list.append(file) + else: + image_list = traverse_all_files(output_dir, image_list, file) + return image_list - file_list = sorted(file_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) + +def get_recent_images(dir_name, page_index, step, image_index, tabname): + page_index = int(page_index) + f_list = os.listdir(dir_name) + image_list = [] + image_list = traverse_all_files(dir_name, image_list) + image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) num = 48 if tabname != "extras" else 12 - max_page_index = len(file_list) // num + 1 + max_page_index = len(image_list) // num + 1 page_index = max_page_index if page_index == -1 else page_index + step page_index = 1 if page_index < 1 else page_index page_index = max_page_index if page_index > max_page_index else page_index idx_frm = (page_index - 1) * num - file_list = file_list[idx_frm:idx_frm + num] - #print(f"Loading history page {page_index}") + image_list = image_list[idx_frm:idx_frm + num] image_index = int(image_index) - if image_index < 0 or image_index > len(file_list) - 1: + if image_index < 0 or image_index > len(image_list) - 1: current_file = None - hide_image = None + hidden = None else: - current_file = file_list[int(image_index)] - hide_image = os.path.join(dir_name, current_file) - return [os.path.join(dir_name, file) for file in file_list], page_index, file_list, current_file, hide_image, "" + current_file = image_list[int(image_index)] + hidden = os.path.join(dir_name, current_file) + return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, "" + def first_page_click(dir_name, page_index, image_index, tabname): return get_recent_images(dir_name, 1, 0, image_index, tabname) def end_page_click(dir_name, page_index, image_index, tabname): @@ -85,6 +88,10 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): dir_name = opts.outdir_img2img_samples elif tabname == "extras": dir_name = opts.outdir_extras_samples + d = dir_name.split("/") + dir_name = d[0] + for p in d[1:]: + dir_name = os.path.join(dir_name, p) with gr.Row(): renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page") first_page = gr.Button('First Page') @@ -109,19 +116,20 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): img_file_name = gr.Textbox(label="File Name", interactive=False) with gr.Row(): # hiden items - img_path = gr.Textbox(dir_name, visible=False) + + img_path = gr.Textbox(dir_name.rstrip("/") , visible=False) tabname_box = gr.Textbox(tabname, visible=False) image_index = gr.Textbox(value=-1, visible=False) set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False) filenames = gr.State() - hide_image = gr.Image(type="pil", visible=False) + hidden = gr.Image(type="pil", visible=False) info1 = gr.Textbox(visible=False) info2 = gr.Textbox(visible=False) # turn pages gallery_inputs = [img_path, page_index, image_index, tabname_box] - gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hide_image, img_file_name] + gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name] first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) @@ -132,12 +140,12 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): #page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index]) #other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hide_image]) + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden]) img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) delete.click(delete_image,_js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num]) - hide_image.change(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) + hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) - #pnginfo.click(fn=run_pnginfo, inputs=[hide_image], outputs=[info1, img_file_info, info2]) + #pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') -- cgit v1.2.3 From fdecb636855748e03efc40c846a0043800aadfcc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 09:05:06 +0300 Subject: add an ability to merge three checkpoints --- modules/extras.py | 29 +++++++++++++++++++++-------- modules/ui.py | 11 +++++++---- 2 files changed, 28 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index b24d7de3..532d869f 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -159,48 +159,61 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name): +def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, interp_amount, save_as_half, custom_name): # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation) - def weighted_sum(theta0, theta1, alpha): + def weighted_sum(theta0, theta1, theta2, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) - def sigmoid(theta0, theta1, alpha): + def sigmoid(theta0, theta1, theta2, alpha): alpha = alpha * alpha * (3 - (2 * alpha)) return theta0 + ((theta1 - theta0) * alpha) # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) - def inv_sigmoid(theta0, theta1, alpha): + def inv_sigmoid(theta0, theta1, theta2, alpha): import math alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0) return theta0 + ((theta1 - theta0) * alpha) + def add_difference(theta0, theta1, theta2, alpha): + return theta0 + (theta1 - theta2) * (1.0 - alpha) + primary_model_info = sd_models.checkpoints_list[primary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name] + teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None) print(f"Loading {primary_model_info.filename}...") primary_model = torch.load(primary_model_info.filename, map_location='cpu') + theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model) print(f"Loading {secondary_model_info.filename}...") secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') - - theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model) theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model) + if teritary_model_info is not None: + print(f"Loading {teritary_model_info.filename}...") + teritary_model = torch.load(teritary_model_info.filename, map_location='cpu') + theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model) + else: + theta_2 = None + theta_funcs = { "Weighted Sum": weighted_sum, "Sigmoid": sigmoid, "Inverse Sigmoid": inv_sigmoid, + "Add difference": add_difference, } theta_func = theta_funcs[interp_method] print(f"Merging...") + for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: - theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key] if theta_2 else None, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint if save_as_half: theta_0[key] = theta_0[key].half() + # I believe this part should be discarded, but I'll leave it for now until I am sure for key in theta_1.keys(): if 'model' in key and key not in theta_0: theta_0[key] = theta_1[key] @@ -219,4 +232,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int sd_models.list_models() print(f"Checkpoint saved.") - return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)] + return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] diff --git a/modules/ui.py b/modules/ui.py index 7446439d..220fb80b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1024,11 +1024,12 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") custom_name = gr.Textbox(label="Custom Name (Optional)") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3) - interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation amount (1 - M)', value=0.3) + interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid", "Add difference"], value="Weighted Sum", label="Interpolation Method") save_as_half = gr.Checkbox(value=False, label="Save as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') @@ -1473,6 +1474,7 @@ Requested path was: {f} inputs=[ primary_model_name, secondary_model_name, + tertiary_model_name, interp_method, interp_amount, save_as_half, @@ -1482,6 +1484,7 @@ Requested path was: {f} submit_result, primary_model_name, secondary_model_name, + tertiary_model_name, component_dict['sd_model_checkpoint'], ] ) -- cgit v1.2.3 From bb57f30c2de46cfca5419ad01738a41705f96cc3 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Fri, 14 Oct 2022 10:56:41 +0200 Subject: init --- modules/processing.py | 17 +++++- modules/sd_hijack.py | 80 +++++++++++++++++++++++++- modules/shared.py | 5 ++ modules/textual_inversion/dataset.py | 2 +- modules/textual_inversion/textual_inversion.py | 35 +++++++---- modules/txt2img.py | 11 +++- modules/ui.py | 59 ++++++++++++------- 7 files changed, 171 insertions(+), 38 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index d5172f00..9a033759 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -316,11 +316,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() -def process_images(p: StableDiffusionProcessing) -> Processed: +def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, + aesthetic_imgs=None,aesthetic_slerp=False) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" + aesthetic_lr = float(aesthetic_lr) + aesthetic_weight = float(aesthetic_weight) + aesthetic_steps = int(aesthetic_steps) + if type(p.prompt) == list: - assert(len(p.prompt) > 0) + assert (len(p.prompt) > 0) else: assert p.prompt is not None @@ -394,7 +399,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): - uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) + if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): + shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0) + uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], + p.steps) + if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): + shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight, + aesthetic_steps, aesthetic_imgs,aesthetic_slerp) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c81722a0..6d5196fe 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -9,11 +9,14 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared -from modules.shared import opts, device, cmd_opts +from modules.shared import opts, device, cmd_opts, aesthetic_embeddings from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model +from transformers import CLIPVisionModel, CLIPModel +import torch.optim as optim +import copy attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity @@ -109,13 +112,29 @@ class StableDiffusionModelHijack: _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) +def slerp(low, high, val): + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm*high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped + self.clipModel = CLIPModel.from_pretrained( + self.wrapped.transformer.name_or_path + ) + del self.clipModel.vision_model self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer + # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval() + self.image_embs_name = None + self.image_embs = None + self.load_image_embs(None) + self.token_mults = {} self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] @@ -136,6 +155,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult + def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None, + aesthetic_slerp=True): + self.slerp = aesthetic_slerp + self.aesthetic_lr = aesthetic_lr + self.aesthetic_weight = aesthetic_weight + self.aesthetic_steps = aesthetic_steps + self.load_image_embs(image_embs_name) + + def load_image_embs(self, image_embs_name): + if image_embs_name is None or len(image_embs_name) == 0: + image_embs_name = None + if image_embs_name is not None and self.image_embs_name != image_embs_name: + self.image_embs_name = image_embs_name + self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device) + self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) + self.image_embs.requires_grad_(False) + def tokenize_line(self, line, used_custom_terms, hijack_comments): id_end = self.wrapped.tokenizer.eos_token_id @@ -333,7 +369,47 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - + + if len(text[ + 0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [ + [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in + remade_batch_tokens] + + tokens = torch.asarray(remade_batch_tokens).to(device) + with torch.enable_grad(): + model = copy.deepcopy(self.clipModel).to(device) + model.requires_grad_(True) + + # We optimize the model to maximize the similarity + optimizer = optim.Adam( + model.text_model.parameters(), lr=self.aesthetic_lr + ) + + for i in range(self.aesthetic_steps): + text_embs = model.get_text_features(input_ids=tokens) + text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) + sim = text_embs @ self.image_embs.T + loss = -sim + optimizer.zero_grad() + loss.mean().backward() + optimizer.step() + + zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + if opts.CLIP_stop_at_last_layers > 1: + zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] + zn = model.text_model.final_layer_norm(zn) + else: + zn = zn.last_hidden_state + model.cpu() + del model + + if self.slerp: + z = slerp(z, zn, self.aesthetic_weight) + else: + z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight + remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 diff --git a/modules/shared.py b/modules/shared.py index 5901e605..cf13a10d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -30,6 +30,8 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(script_path, 'aesthetic_embeddings'), + help="aesthetic_embeddings directory(default: aesthetic_embeddings)") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") @@ -90,6 +92,9 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None +aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in + os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} + def reload_hypernetworks(): global hypernetworks diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 67e90afe..59b2b021 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -48,7 +48,7 @@ class PersonalizedBase(Dataset): print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) + image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC) except Exception: continue diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index fa0e33a2..b12a8e6d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -172,7 +172,15 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): +def batched(dataset, total, n=1): + for ndx in range(0, total, n): + yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] + + +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, + create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, + preview_image_prompt, batch_size=1, + gradient_accumulation=1): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -204,7 +212,11 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, + height=training_height, + repeats=shared.opts.training_image_repeats_per_epoch, + placeholder_token=embedding_name, model=shared.sd_model, + device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -223,7 +235,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + pbar = tqdm.tqdm(enumerate(batched(ds, steps - ititial_step, batch_size)), total=steps - ititial_step) for i, entry in pbar: embedding.step = i + ititial_step @@ -235,17 +247,20 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini break with torch.autocast("cuda"): - c = cond_model([entry.cond_text]) + c = cond_model([e.cond_text for e in entry]) + + x = torch.stack([e.latent for e in entry]).to(devices.device) + loss = shared.sd_model(x, c)[0] - x = entry.latent.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] del x losses[embedding.step % losses.shape[0]] = loss.item() - optimizer.zero_grad() loss.backward() - optimizer.step() + if ((i + 1) % gradient_accumulation == 0) or (i + 1 == steps - ititial_step): + optimizer.step() + optimizer.zero_grad() + epoch_num = embedding.step // len(ds) epoch_step = embedding.step - (epoch_num * len(ds)) + 1 @@ -259,7 +274,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') - preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt + preview_text = entry[0].cond_text if preview_image_prompt == "" else preview_image_prompt p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, @@ -305,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini

Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entry[-1].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

diff --git a/modules/txt2img.py b/modules/txt2img.py index e985242b..78342024 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -6,7 +6,14 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, + restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, + subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, + height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, + aesthetic_lr=0, + aesthetic_weight=0, aesthetic_steps=0, + aesthetic_imgs=None, + aesthetic_slerp=False, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -40,7 +47,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is None: - processed = process_images(p) + processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp) shared.total_tqdm.clear() diff --git a/modules/ui.py b/modules/ui.py index 220fb80b..d961d126 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -24,7 +24,8 @@ import gradio.routes from modules import sd_hijack from modules.paths import script_path -from modules.shared import opts, cmd_opts +from modules.shared import opts, cmd_opts,aesthetic_embeddings + if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags import modules.shared as shared @@ -534,6 +535,14 @@ def create_ui(wrap_gradio_gpu_call): width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + with gr.Group(): + aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7) + aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50) + + aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) + aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) + with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) tiling = gr.Checkbox(label='Tiling', value=False) @@ -586,25 +595,30 @@ def create_ui(wrap_gradio_gpu_call): fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), _js="submit", inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - scale_latent, - denoising_strength, - ] + custom_inputs, + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_style, + txt2img_prompt_style2, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + scale_latent, + denoising_strength, + aesthetic_lr, + aesthetic_weight, + aesthetic_steps, + aesthetic_imgs, + aesthetic_slerp + ] + custom_inputs, outputs=[ txt2img_gallery, generation_info, @@ -1097,6 +1111,9 @@ def create_ui(wrap_gradio_gpu_call): template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + batch_size = gr.Slider(minimum=1, maximum=64, step=1, label="Batch Size", value=4) + gradient_accumulation = gr.Slider(minimum=1, maximum=256, step=1, label="Gradient accumulation", + value=1) steps = gr.Number(label='Max steps', value=100000, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) @@ -1180,6 +1197,8 @@ def create_ui(wrap_gradio_gpu_call): template_file, save_image_with_stored_embedding, preview_image_prompt, + batch_size, + gradient_accumulation ], outputs=[ ti_output, -- cgit v1.2.3 From fdef8253a43ca5135923092ca9b85e878d980869 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 14 Oct 2022 04:42:53 -0400 Subject: Add 'interrogate' and 'all' choices to --use-cpu * Add 'interrogate' and 'all' choices to --use-cpu * Change type for --use-cpu argument to str.lower, so that choices are case insensitive --- modules/devices.py | 2 +- modules/interrogate.py | 14 +++++++------- modules/shared.py | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 03ef58f1..eb422583 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -34,7 +34,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") -device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() +device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() dtype = torch.float16 dtype_vae = torch.float16 diff --git a/modules/interrogate.py b/modules/interrogate.py index af858cc0..9263d65a 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -55,7 +55,7 @@ class InterrogateModels: model, preprocess = clip.load(clip_model_name) model.eval() - model = model.to(shared.device) + model = model.to(devices.device_interrogate) return model, preprocess @@ -65,14 +65,14 @@ class InterrogateModels: if not shared.cmd_opts.no_half: self.blip_model = self.blip_model.half() - self.blip_model = self.blip_model.to(shared.device) + self.blip_model = self.blip_model.to(devices.device_interrogate) if self.clip_model is None: self.clip_model, self.clip_preprocess = self.load_clip_model() if not shared.cmd_opts.no_half: self.clip_model = self.clip_model.half() - self.clip_model = self.clip_model.to(shared.device) + self.clip_model = self.clip_model.to(devices.device_interrogate) self.dtype = next(self.clip_model.parameters()).dtype @@ -99,11 +99,11 @@ class InterrogateModels: text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] top_count = min(top_count, len(text_array)) - text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device) + text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate) text_features = self.clip_model.encode_text(text_tokens).type(self.dtype) text_features /= text_features.norm(dim=-1, keepdim=True) - similarity = torch.zeros((1, len(text_array))).to(shared.device) + similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate) for i in range(image_features.shape[0]): similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) similarity /= image_features.shape[0] @@ -116,7 +116,7 @@ class InterrogateModels: transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) + ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate) with torch.no_grad(): caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length) @@ -140,7 +140,7 @@ class InterrogateModels: res = caption - clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) + clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate) precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext with torch.no_grad(), precision_scope("cuda"): diff --git a/modules/shared.py b/modules/shared.py index 5901e605..b6a5c1a8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -54,7 +54,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") -parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[]) +parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) @@ -76,8 +76,8 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl cmd_opts = parser.parse_args() -devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ -(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer']) +devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ +(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer']) device = devices.device -- cgit v1.2.3 From 9e5ca5077f43bb3ec1a0ec41b47964cb38d544a6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 16:37:32 +0300 Subject: extra message for unpicking fails --- modules/safe.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/safe.py b/modules/safe.py index 20be16a5..399165a1 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -96,11 +96,18 @@ def load(filename, *args, **kwargs): if not shared.cmd_opts.disable_safe_unpickle: check_pt(filename) + except pickle.UnpicklingError: + print(f"Error verifying pickled file from {filename}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) + print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) + return None + except Exception: print(f"Error verifying pickled file from {filename}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) - print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr) + print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) return None return unsafe_torch_load(filename, *args, **kwargs) -- cgit v1.2.3 From b2261b53ae4ad01b3713bc73ff62ab7b6f479e26 Mon Sep 17 00:00:00 2001 From: Buckzor Date: Thu, 13 Oct 2022 17:07:06 +0100 Subject: Added first_pass_width and height as adjustable inputs to "High Res Fix" --- modules/processing.py | 6 ++++-- modules/txt2img.py | 5 ++++- modules/ui.py | 6 ++++++ 3 files changed, 14 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index d5172f00..abbfdf98 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -506,11 +506,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): firstphase_width_truncated = 0 firstphase_height_truncated = 0 - def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs): + def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, first_pass_width=512, first_pass_height=512, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.scale_latent = scale_latent self.denoising_strength = denoising_strength + self.first_pass_width = first_pass_width + self.first_pass_height = first_pass_height def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: @@ -519,7 +521,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 - desired_pixel_count = 512 * 512 + desired_pixel_count = self.first_pass_width * self.first_pass_height actual_pixel_count = self.width * self.height scale = math.sqrt(desired_pixel_count / actual_pixel_count) diff --git a/modules/txt2img.py b/modules/txt2img.py index e985242b..85cbece4 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -6,7 +6,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, first_pass_width: int, first_pass_height: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -32,6 +32,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: enable_hr=enable_hr, scale_latent=scale_latent if enable_hr else None, denoising_strength=denoising_strength if enable_hr else None, + first_pass_width=first_pass_width if enable_hr else None, + first_pass_height=first_pass_height if enable_hr else None, + ) if cmd_opts.enable_console_prompts: diff --git a/modules/ui.py b/modules/ui.py index 220fb80b..544419b2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -540,6 +540,8 @@ def create_ui(wrap_gradio_gpu_call): enable_hr = gr.Checkbox(label='Highres. fix', value=False) with gr.Row(visible=False) as hr_options: + first_pass_width = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass width", value=512) + first_pass_height = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass height", value=512) scale_latent = gr.Checkbox(label='Scale latent', value=False) denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) @@ -604,6 +606,8 @@ def create_ui(wrap_gradio_gpu_call): enable_hr, scale_latent, denoising_strength, + first_pass_width, + first_pass_height, ] + custom_inputs, outputs=[ txt2img_gallery, @@ -668,6 +672,8 @@ def create_ui(wrap_gradio_gpu_call): (denoising_strength, "Denoising strength"), (enable_hr, lambda d: "Denoising strength" in d), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (first_pass_width, "First pass width"), + (first_pass_height, "First pass height"), ] modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) -- cgit v1.2.3 From 40d1c6e423b4dc52b3bdae43d9e2442960760ced Mon Sep 17 00:00:00 2001 From: Buckzor Date: Thu, 13 Oct 2022 20:04:22 +0100 Subject: Option between stretch and crop for Highres. fix --- modules/processing.py | 34 ++++++++++++++++++++++------------ modules/txt2img.py | 7 ++++--- modules/ui.py | 25 ++++++++++++++++--------- 3 files changed, 42 insertions(+), 24 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index abbfdf98..0246f5dd 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -506,13 +506,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): firstphase_width_truncated = 0 firstphase_height_truncated = 0 - def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, first_pass_width=512, first_pass_height=512, **kwargs): + def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, firstphase_width=512, firstphase_height=512, crop_scale=False, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.scale_latent = scale_latent self.denoising_strength = denoising_strength - self.first_pass_width = first_pass_width - self.first_pass_height = first_pass_height + self.firstphase_width = firstphase_width + self.firstphase_height = firstphase_height + self.crop_scale = crop_scale def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: @@ -521,14 +522,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 - desired_pixel_count = self.first_pass_width * self.first_pass_height - actual_pixel_count = self.width * self.height - scale = math.sqrt(desired_pixel_count / actual_pixel_count) + #desired_pixel_count = self.firstphase_width * self.firstphase_height + #actual_pixel_count = self.width * self.height + #scale = math.sqrt(desired_pixel_count / actual_pixel_count) - self.firstphase_width = math.ceil(scale * self.width / 64) * 64 - self.firstphase_height = math.ceil(scale * self.height / 64) * 64 - self.firstphase_width_truncated = int(scale * self.width) - self.firstphase_height_truncated = int(scale * self.height) + #self.firstphase_width = math.ceil(scale * self.width / 64) * 64 + #self.firstphase_height = math.ceil(scale * self.height / 64) * 64 + #self.firstphase_width_truncated = int(scale * self.width) + #self.firstphase_height_truncated = int(scale * self.height) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) @@ -541,8 +542,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) - truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f - truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f + truncate_x = 0 + truncate_y = 0 + + if self.crop_scale: + if self.width/self.firstphase_width > self.height/self.firstphase_height: + #Crop to landscape + truncate_y = (self.width - self.firstphase_width)//2 // opt_f + + elif self.width/self.firstphase_width < self.height/self.firstphase_height: + #Crop to portrait + truncate_x = (self.height - self.firstphase_height)//2 // opt_f samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2] diff --git a/modules/txt2img.py b/modules/txt2img.py index 85cbece4..447ec3d3 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -6,7 +6,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, first_pass_width: int, first_pass_height: int, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, crop_scale: bool, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -32,8 +32,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: enable_hr=enable_hr, scale_latent=scale_latent if enable_hr else None, denoising_strength=denoising_strength if enable_hr else None, - first_pass_width=first_pass_width if enable_hr else None, - first_pass_height=first_pass_height if enable_hr else None, + firstphase_width=firstphase_width if enable_hr else None, + firstphase_height=firstphase_height if enable_hr else None, + crop_scale=crop_scale if enable_hr else None, ) diff --git a/modules/ui.py b/modules/ui.py index 544419b2..f2d81f68 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -540,12 +540,18 @@ def create_ui(wrap_gradio_gpu_call): enable_hr = gr.Checkbox(label='Highres. fix', value=False) with gr.Row(visible=False) as hr_options: - first_pass_width = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass width", value=512) - first_pass_height = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass height", value=512) - scale_latent = gr.Checkbox(label='Scale latent', value=False) - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) + with gr.Column(scale=1.0): + firstphase_width = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass width", value=512) + firstphase_height = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass height", value=512) + + with gr.Column(scale=1.0): + with gr.Row(): + crop_scale = gr.Checkbox(label='Crop when scaling', value=False) + scale_latent = gr.Checkbox(label='Scale latent', value=False) + with gr.Row(): + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) - with gr.Row(): + with gr.Row(equal_height=True): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) @@ -606,8 +612,9 @@ def create_ui(wrap_gradio_gpu_call): enable_hr, scale_latent, denoising_strength, - first_pass_width, - first_pass_height, + firstphase_width, + firstphase_height, + crop_scale, ] + custom_inputs, outputs=[ txt2img_gallery, @@ -672,8 +679,8 @@ def create_ui(wrap_gradio_gpu_call): (denoising_strength, "Denoising strength"), (enable_hr, lambda d: "Denoising strength" in d), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (first_pass_width, "First pass width"), - (first_pass_height, "First pass height"), + (firstphase_width, "First pass width"), + (firstphase_height, "First pass height"), ] modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) -- cgit v1.2.3 From b382de2d77c653c565840ce92d27aa668a1934d7 Mon Sep 17 00:00:00 2001 From: Buckzor Date: Thu, 13 Oct 2022 22:23:22 +0100 Subject: Fixed Scale ratio problem --- modules/processing.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 0246f5dd..d9b0e0e7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -522,15 +522,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 - #desired_pixel_count = self.firstphase_width * self.firstphase_height - #actual_pixel_count = self.width * self.height - #scale = math.sqrt(desired_pixel_count / actual_pixel_count) - - #self.firstphase_width = math.ceil(scale * self.width / 64) * 64 - #self.firstphase_height = math.ceil(scale * self.height / 64) * 64 - #self.firstphase_width_truncated = int(scale * self.width) - #self.firstphase_height_truncated = int(scale * self.height) - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) @@ -544,17 +535,23 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): truncate_x = 0 truncate_y = 0 + width_ratio = self.width/self.firstphase_width + height_ratio = self.height/self.firstphase_height if self.crop_scale: - if self.width/self.firstphase_width > self.height/self.firstphase_height: + if width_ratio > height_ratio: #Crop to landscape - truncate_y = (self.width - self.firstphase_width)//2 // opt_f + truncate_y = int((self.width - self.firstphase_width) / width_ratio / height_ratio / opt_f) - elif self.width/self.firstphase_width < self.height/self.firstphase_height: + elif width_ratio < height_ratio: #Crop to portrait - truncate_x = (self.height - self.firstphase_height)//2 // opt_f + truncate_x = int((self.height - self.firstphase_height) / width_ratio / height_ratio / opt_f) + + samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2] + + - samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2] + if self.scale_latent: samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") -- cgit v1.2.3 From e644b5a80beb54b6df4caa63fb19d889dd4ceff6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 17:03:03 +0300 Subject: remove scale latent and no-crop options from hires fix support copy-pasting new parameters for hires fix --- modules/processing.py | 64 ++++++++++++++++++++++----------------------------- modules/txt2img.py | 9 +++----- modules/ui.py | 19 ++++----------- 3 files changed, 35 insertions(+), 57 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index d9b0e0e7..100a259f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -506,14 +506,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): firstphase_width_truncated = 0 firstphase_height_truncated = 0 - def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, firstphase_width=512, firstphase_height=512, crop_scale=False, **kwargs): + def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=512, firstphase_height=512, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr - self.scale_latent = scale_latent self.denoising_strength = denoising_strength self.firstphase_width = firstphase_width self.firstphase_height = firstphase_height - self.crop_scale = crop_scale def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: @@ -530,6 +528,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) return samples + self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" + x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) @@ -538,46 +538,36 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): width_ratio = self.width/self.firstphase_width height_ratio = self.height/self.firstphase_height - if self.crop_scale: - if width_ratio > height_ratio: - #Crop to landscape - truncate_y = int((self.width - self.firstphase_width) / width_ratio / height_ratio / opt_f) + if width_ratio > height_ratio: + truncate_y = int((self.width - self.firstphase_width) / width_ratio / height_ratio / opt_f) - elif width_ratio < height_ratio: - #Crop to portrait - truncate_x = int((self.height - self.firstphase_height) / width_ratio / height_ratio / opt_f) + elif width_ratio < height_ratio: + truncate_x = int((self.height - self.firstphase_height) / width_ratio / height_ratio / opt_f) - samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2] - - + samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2] - + decoded_samples = decode_first_stage(self.sd_model, samples) - if self.scale_latent: - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": + decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") else: - decoded_samples = decode_first_stage(self.sd_model, samples) + lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) - if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": - decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") - else: - lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) - - batch_images = [] - for i, x_sample in enumerate(lowres_samples): - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - image = Image.fromarray(x_sample) - image = images.resize_image(0, image, self.width, self.height) - image = np.array(image).astype(np.float32) / 255.0 - image = np.moveaxis(image, 2, 0) - batch_images.append(image) - - decoded_samples = torch.from_numpy(np.array(batch_images)) - decoded_samples = decoded_samples.to(shared.device) - decoded_samples = 2. * decoded_samples - 1. - - samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) + batch_images = [] + for i, x_sample in enumerate(lowres_samples): + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + image = Image.fromarray(x_sample) + image = images.resize_image(0, image, self.width, self.height) + image = np.array(image).astype(np.float32) / 255.0 + image = np.moveaxis(image, 2, 0) + batch_images.append(image) + + decoded_samples = torch.from_numpy(np.array(batch_images)) + decoded_samples = decoded_samples.to(shared.device) + decoded_samples = 2. * decoded_samples - 1. + + samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) shared.state.nextjob() diff --git a/modules/txt2img.py b/modules/txt2img.py index 447ec3d3..2381347f 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -6,7 +6,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, crop_scale: bool, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -30,12 +30,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: restore_faces=restore_faces, tiling=tiling, enable_hr=enable_hr, - scale_latent=scale_latent if enable_hr else None, denoising_strength=denoising_strength if enable_hr else None, - firstphase_width=firstphase_width if enable_hr else None, - firstphase_height=firstphase_height if enable_hr else None, - crop_scale=crop_scale if enable_hr else None, - + firstphase_width=firstphase_width if enable_hr else None, + firstphase_height=firstphase_height if enable_hr else None, ) if cmd_opts.enable_console_prompts: diff --git a/modules/ui.py b/modules/ui.py index f2d81f68..d66ddc14 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -540,16 +540,9 @@ def create_ui(wrap_gradio_gpu_call): enable_hr = gr.Checkbox(label='Highres. fix', value=False) with gr.Row(visible=False) as hr_options: - with gr.Column(scale=1.0): - firstphase_width = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass width", value=512) - firstphase_height = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass height", value=512) - - with gr.Column(scale=1.0): - with gr.Row(): - crop_scale = gr.Checkbox(label='Crop when scaling', value=False) - scale_latent = gr.Checkbox(label='Scale latent', value=False) - with gr.Row(): - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) + firstphase_width = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass width", value=512) + firstphase_height = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass height", value=512) + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) with gr.Row(equal_height=True): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) @@ -610,11 +603,9 @@ def create_ui(wrap_gradio_gpu_call): height, width, enable_hr, - scale_latent, denoising_strength, firstphase_width, firstphase_height, - crop_scale, ] + custom_inputs, outputs=[ txt2img_gallery, @@ -679,8 +670,8 @@ def create_ui(wrap_gradio_gpu_call): (denoising_strength, "Denoising strength"), (enable_hr, lambda d: "Denoising strength" in d), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (firstphase_width, "First pass width"), - (firstphase_height, "First pass height"), + (firstphase_width, "First pass size-1"), + (firstphase_height, "First pass size-2"), ] modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) -- cgit v1.2.3 From 0aec19d7837d8564355fdb286541db7165852e41 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 18:15:03 +0300 Subject: make pasting into img2img prompt work make image params request not use temp files --- modules/images.py | 36 ++++++++++++++++++------------------ modules/ui.py | 9 +++------ 2 files changed, 21 insertions(+), 24 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index f1155b7f..68cdbc93 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,4 +1,5 @@ import datetime +import io import math import os from collections import namedtuple @@ -465,21 +466,20 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i return fullfn, txt_fullfn -def image_data(image_path): - file, ext = os.path.splitext(image_path.name) - data = {} - if "png" in ext: - image = Image.open(image_path.name, "r") - print(f"Image data requested for {image_path.name} {image.format} of {type(image)}") - try: - data = image.text["parameters"] - except Exception as e: - print(f"Exception: {e}") - pass - print(f"Image data: {data}") - if "txt" in ext: - myfile = open(image_path.name, 'r') - data = myfile.read() - myfile.close() - - return data, None +def image_data(data): + try: + image = Image.open(io.BytesIO(data)) + textinfo = image.text["parameters"] + return textinfo, None + except Exception: + pass + + try: + text = data.decode('utf8') + assert len(text) < 10000 + return text, None + + except Exception: + pass + + return '', None diff --git a/modules/ui.py b/modules/ui.py index 0a3ee887..6266db49 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -514,7 +514,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="file", visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) with gr.Row(elem_id='txt2img_progress_row'): with gr.Column(scale=1): @@ -620,7 +620,6 @@ def create_ui(wrap_gradio_gpu_call): txt_prompt_img.change( fn=modules.images.image_data, - # _js = "get_extras_tab_index", inputs=[ txt_prompt_img ], @@ -692,8 +691,7 @@ def create_ui(wrap_gradio_gpu_call): img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="txt_prompt_image", file_count="single", type="file", - visible=False) + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) with gr.Column(scale=1): pass @@ -791,9 +789,8 @@ def create_ui(wrap_gradio_gpu_call): img2img_prompt_img.change( fn=modules.images.image_data, - # _js = "get_extras_tab_index", inputs=[ - txt_prompt_img + img2img_prompt_img ], outputs=[ img2img_prompt, -- cgit v1.2.3 From 67f447ddcc8a17d11939c3801dca635dc22944c7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 19:30:28 +0300 Subject: possibility to load checkpoint, clip skip, and hypernet from infotext --- modules/ui.py | 52 +++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 6266db49..a37a4e17 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -22,7 +22,7 @@ import gradio as gr import gradio.utils import gradio.routes -from modules import sd_hijack +from modules import sd_hijack, sd_models from modules.paths import script_path from modules.shared import opts, cmd_opts if cmd_opts.deepdanbooru: @@ -507,12 +507,38 @@ def setup_progressbar(progressbar, preview, id_part, textinfo=None): ) +def apply_setting(key, value): + if value is None: + return gr.update() + + if key == "sd_model_checkpoint": + ckpt_info = sd_models.get_closet_checkpoint_match(value) + + if ckpt_info is not None: + value = ckpt_info.title + else: + return gr.update() + + comp_args = opts.data_labels[key].component_args + if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: + return + + valtype = type(opts.data_labels[key].default) + oldval = opts.data[key] + opts.data[key] = valtype(value) if valtype != type(None) else value + if oldval != value and opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + return value + + def create_ui(wrap_gradio_gpu_call): import modules.img2img import modules.txt2img with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False) + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -684,11 +710,10 @@ def create_ui(wrap_gradio_gpu_call): (firstphase_width, "First pass size-1"), (firstphase_height, "First pass size-2"), ] - modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) + img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -938,7 +963,6 @@ def create_ui(wrap_gradio_gpu_call): (seed_resize_from_h, "Seed resize from-2"), (denoising_strength, "Denoising strength"), ] - modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as extras_interface: @@ -1580,8 +1604,22 @@ Requested path was: {f} outputs=[extras_image], ) - modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields, generation_info, 'switch_to_txt2img') - modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields, generation_info, 'switch_to_img2img_img2img') + settings_map = { + 'sd_hypernetwork': 'Hypernet', + 'CLIP_stop_at_last_layers': 'Clip skip', + 'sd_model_checkpoint': 'Model hash', + } + + settings_paste_fields = [ + (component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None))) + for k, v in settings_map.items() + ] + + modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt) + modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt) + + modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img') + modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img') ui_config_file = cmd_opts.ui_config_file ui_settings = {} -- cgit v1.2.3 From 2fb9891af3bb4c36a6de6b44937e927bda43c10d Mon Sep 17 00:00:00 2001 From: Gugubo <29143981+Gugubo@users.noreply.github.com> Date: Fri, 14 Oct 2022 14:19:39 +0200 Subject: Change grid row count autodetect to prevent empty spots Instead of just rounding (sometimes resulting in grids with "empty" spots), find a divisor. For example: 8 images will now result in a 4x2 grid instead of a 3x3 with one empty spot. --- modules/images.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 68cdbc93..90eca37a 100644 --- a/modules/images.py +++ b/modules/images.py @@ -25,8 +25,9 @@ def image_grid(imgs, batch_size=1, rows=None): elif opts.n_rows == 0: rows = batch_size else: - rows = math.sqrt(len(imgs)) - rows = round(rows) + rows = math.floor(math.sqrt(len(imgs))) + while len(imgs) % rows != 0: + rows -= 1 cols = math.ceil(len(imgs) / rows) -- cgit v1.2.3 From 43f926aad1b77a4bb642c1d173adfae1f56cf42d Mon Sep 17 00:00:00 2001 From: Gugubo <29143981+Gugubo@users.noreply.github.com> Date: Fri, 14 Oct 2022 17:06:51 +0200 Subject: Add option to prevent empty spots in grid (1/2) --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index b6a5c1a8..159f504f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -175,6 +175,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "grid_format": OptionInfo('png', 'File format for grids'), "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), + "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"), "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), -- cgit v1.2.3 From 5f87dd1ee0960963e3f756c4ebe47652ff57f715 Mon Sep 17 00:00:00 2001 From: Gugubo <29143981+Gugubo@users.noreply.github.com> Date: Fri, 14 Oct 2022 17:07:24 +0200 Subject: Add option to prevent empty spots in grid (2/2) --- modules/images.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 90eca37a..b9589563 100644 --- a/modules/images.py +++ b/modules/images.py @@ -24,10 +24,13 @@ def image_grid(imgs, batch_size=1, rows=None): rows = opts.n_rows elif opts.n_rows == 0: rows = batch_size - else: + elif opts.grid_prevent_empty_spots: rows = math.floor(math.sqrt(len(imgs))) while len(imgs) % rows != 0: rows -= 1 + else: + rows = math.sqrt(len(imgs)) + rows = round(rows) cols = math.ceil(len(imgs) / rows) -- cgit v1.2.3 From a8eeb2b7ad0c43ad60ac2ba8bd299b9cb265fdd3 Mon Sep 17 00:00:00 2001 From: Ljzd-PRO <63289359+Ljzd-PRO@users.noreply.github.com> Date: Thu, 13 Oct 2022 02:03:08 +0800 Subject: add `--lowram` parameter load models to VRM instead of RAM (for machines which have bigger VRM than RAM such as free Google Colab server) --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 159f504f..cd4a4714 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -34,6 +34,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_ parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") +parser.add_argument("--lowram", action='store_true', help="load models to VRM instead of RAM (for machines which have bigger VRM than RAM such as free Google Colab server)") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") -- cgit v1.2.3 From 4a216ded433ded315106e2989c5ff7dec1c49304 Mon Sep 17 00:00:00 2001 From: Ljzd-PRO <63289359+Ljzd-PRO@users.noreply.github.com> Date: Thu, 13 Oct 2022 02:07:49 +0800 Subject: load models to VRAM when using `--lowram` param load models to VRM instead of RAM (for machines which have bigger VRM than RAM such as free Google Colab server) --- modules/sd_models.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a55b4c3..78a198b9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -134,7 +134,12 @@ def load_model_weights(model, checkpoint_info): print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - pl_sd = torch.load(checkpoint_file, map_location="cpu") + if shared.cmd_opts.lowram: + print("Load to VRAM if GPU is available (low RAM)") + pl_sd = torch.load(checkpoint_file) + else: + pl_sd = torch.load(checkpoint_file, map_location="cpu") + if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") @@ -158,7 +163,13 @@ def load_model_weights(model, checkpoint_info): if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") - vae_ckpt = torch.load(vae_file, map_location="cpu") + + if shared.cmd_opts.lowram: + print("Load to VRAM if GPU is available (low RAM)") + vae_ckpt = torch.load(vae_file) + else: + vae_ckpt = torch.load(vae_file, map_location="cpu") + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} model.first_stage_model.load_state_dict(vae_dict) -- cgit v1.2.3 From bb295f54785ac36dc6aa6f7103a3431464440fc3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 20:03:41 +0300 Subject: rework the code for lowram a bit --- modules/sd_models.py | 12 ++---------- modules/shared.py | 3 ++- 2 files changed, 4 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 78a198b9..3a01c93d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -134,11 +134,7 @@ def load_model_weights(model, checkpoint_info): print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - if shared.cmd_opts.lowram: - print("Load to VRAM if GPU is available (low RAM)") - pl_sd = torch.load(checkpoint_file) - else: - pl_sd = torch.load(checkpoint_file, map_location="cpu") + pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") @@ -164,11 +160,7 @@ def load_model_weights(model, checkpoint_info): if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") - if shared.cmd_opts.lowram: - print("Load to VRAM if GPU is available (low RAM)") - vae_ckpt = torch.load(vae_file) - else: - vae_ckpt = torch.load(vae_file, map_location="cpu") + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} diff --git a/modules/shared.py b/modules/shared.py index cd4a4714..695d29b6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -34,7 +34,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_ parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") -parser.add_argument("--lowram", action='store_true', help="load models to VRM instead of RAM (for machines which have bigger VRM than RAM such as free Google Colab server)") +parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") @@ -81,6 +81,7 @@ devices.device, devices.device_interrogate, devices.device_gfpgan, devices.devic (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer']) device = devices.device +weight_load_location = None if cmd_opts.lowram else "cpu" batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram -- cgit v1.2.3 From c344ba3b325459abbf9b0df2c1b18f7bf99805b2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 20:31:49 +0300 Subject: add option to read generation params for learning previews from txt2img --- modules/hypernetworks/hypernetwork.py | 21 ++++++++++++++++----- modules/textual_inversion/textual_inversion.py | 25 ++++++++++++++++++------- modules/ui.py | 20 +++++++++++++++++--- 3 files changed, 51 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index f1248bb7..e5cb1817 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -180,7 +180,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): return self.to_out(out) -def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): +def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert hypernetwork_name, 'hypernetwork not selected' path = shared.hypernetworks.get(hypernetwork_name, None) @@ -265,20 +265,31 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') - preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt - optimizer.zero_grad() shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, - prompt=preview_text, - steps=20, do_not_save_grid=True, do_not_save_samples=True, ) + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entry.cond_text + p.steps = 20 + + preview_text = p.prompt + processed = processing.process_images(p) image = processed.images[0] if len(processed.images)>0 else None diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index fa0e33a2..3d835358 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -172,7 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -259,18 +259,29 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') - preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt - p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, - prompt=preview_text, - steps=20, - height=training_height, - width=training_width, do_not_save_grid=True, do_not_save_samples=True, ) + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entry.cond_text + p.steps = 20 + p.width = training_width + p.height = training_height + + preview_text = p.prompt + processed = processing.process_images(p) image = processed.images[0] diff --git a/modules/ui.py b/modules/ui.py index 828bfeea..4a04c2cc 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -711,6 +711,18 @@ def create_ui(wrap_gradio_gpu_call): (firstphase_width, "First pass size-1"), (firstphase_height, "First pass size-2"), ] + + txt2img_preview_params = [ + txt2img_prompt, + txt2img_negative_prompt, + steps, + sampler_index, + cfg_scale, + seed, + width, + height, + ] + token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as img2img_interface: @@ -1162,7 +1174,7 @@ def create_ui(wrap_gradio_gpu_call): create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) - preview_image_prompt = gr.Textbox(label='Preview prompt', value="") + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) with gr.Row(): interrupt_training = gr.Button(value="Interrupt") @@ -1240,7 +1252,8 @@ def create_ui(wrap_gradio_gpu_call): save_embedding_every, template_file, save_image_with_stored_embedding, - preview_image_prompt, + preview_from_txt2img, + *txt2img_preview_params, ], outputs=[ ti_output, @@ -1260,7 +1273,8 @@ def create_ui(wrap_gradio_gpu_call): create_image_every, save_embedding_every, template_file, - preview_image_prompt, + preview_from_txt2img, + *txt2img_preview_params, ], outputs=[ ti_output, -- cgit v1.2.3 From 2f0e089c7c8e1ad7d2ad658971c6fdec9622e3ab Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 21:20:28 +0300 Subject: should fix the issue with missing layers in chechpoint merger --- modules/extras.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 532d869f..2e7b3751 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -209,7 +209,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: - theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key] if theta_2 else None, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + t2 = (theta_2 or {}).get(key) + if t2 is None: + t2 = torch.zeros_like(theta_0[key]) + + theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + if save_as_half: theta_0[key] = theta_0[key].half() -- cgit v1.2.3 From c250cb289c97fe303cef69064bf45899406f6a40 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 22:01:49 +0300 Subject: change checkpoint merger to work in a more obvious way remove sigmoid and inverse sigmoid because they just did the same thing as weighed sum only with changed multiplier --- modules/extras.py | 24 +++++------------------- modules/ui.py | 4 ++-- 2 files changed, 7 insertions(+), 21 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 2e7b3751..f2f5a7b0 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -159,24 +159,12 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, interp_amount, save_as_half, custom_name): - # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation) +def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name): def weighted_sum(theta0, theta1, theta2, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) - # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) - def sigmoid(theta0, theta1, theta2, alpha): - alpha = alpha * alpha * (3 - (2 * alpha)) - return theta0 + ((theta1 - theta0) * alpha) - - # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) - def inv_sigmoid(theta0, theta1, theta2, alpha): - import math - alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0) - return theta0 + ((theta1 - theta0) * alpha) - def add_difference(theta0, theta1, theta2, alpha): - return theta0 + (theta1 - theta2) * (1.0 - alpha) + return theta0 + (theta1 - theta2) * alpha primary_model_info = sd_models.checkpoints_list[primary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name] @@ -198,9 +186,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam theta_2 = None theta_funcs = { - "Weighted Sum": weighted_sum, - "Sigmoid": sigmoid, - "Inverse Sigmoid": inv_sigmoid, + "Weighted sum": weighted_sum, "Add difference": add_difference, } theta_func = theta_funcs[interp_method] @@ -213,7 +199,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam if t2 is None: t2 = torch.zeros_like(theta_0[key]) - theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier) if save_as_half: theta_0[key] = theta_0[key].half() @@ -227,7 +213,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' + filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' filename = filename if custom_name == '' else (custom_name + '.ckpt') output_modelname = os.path.join(ckpt_dir, filename) diff --git a/modules/ui.py b/modules/ui.py index 4a04c2cc..a08ffc9b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1101,8 +1101,8 @@ def create_ui(wrap_gradio_gpu_call): secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") custom_name = gr.Textbox(label="Custom Name (Optional)") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation amount (1 - M)', value=0.3) - interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid", "Add difference"], value="Weighted Sum", label="Interpolation Method") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") save_as_half = gr.Checkbox(value=False, label="Save as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') -- cgit v1.2.3 From 03d62538aebeff51713619fe808c953bdb70193d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 22:43:55 +0300 Subject: remove duplicate code for log loss, add step, make it read from options rather than gradio input --- modules/hypernetworks/hypernetwork.py | 20 ++++-------- modules/shared.py | 3 +- modules/textual_inversion/textual_inversion.py | 44 ++++++++++++++++++-------- modules/ui.py | 3 -- 4 files changed, 38 insertions(+), 32 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index edb8cba1..59c7ac6e 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -15,6 +15,7 @@ import torch from torch import einsum from einops import rearrange, repeat import modules.textual_inversion.dataset +from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -210,7 +211,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) @@ -263,19 +264,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') hypernetwork.save(last_saved_file) - if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0: - write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True - - with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout: - - csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss", "learn_rate"]) - - if write_csv_header: - csv_writer.writeheader() - - csv_writer.writerow({"step": hypernetwork.step, - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate}) + textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate + }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') diff --git a/modules/shared.py b/modules/shared.py index 695d29b6..d41a7ab3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -236,7 +236,8 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), - "training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), + "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), + "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 1f5ace6f..da0d77a0 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -173,6 +173,32 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn +def write_loss(log_directory, filename, step, epoch_len, values): + if shared.opts.training_write_csv_every == 0: + return + + if step % shared.opts.training_write_csv_every != 0: + return + + write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True + + with open(os.path.join(log_directory, filename), "a+", newline='') as fout: + csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())]) + + if write_csv_header: + csv_writer.writeheader() + + epoch = step // epoch_len + epoch_step = step - epoch * epoch_len + + csv_writer.writerow({ + "step": step + 1, + "epoch": epoch + 1, + "epoch_step": epoch_step + 1, + **values, + }) + + def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert embedding_name, 'embedding not selected' @@ -257,20 +283,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') embedding.save(last_saved_file) - if write_csv_every > 0 and log_directory is not None and embedding.step % write_csv_every == 0: - write_csv_header = False if os.path.exists(os.path.join(log_directory, "textual_inversion_loss.csv")) else True - - with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout: - - csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss", "learn_rate"]) - - if write_csv_header: - csv_writer.writeheader() - - csv_writer.writerow({"epoch": epoch_num + 1, - "epoch_step": epoch_step - 1, - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate}) + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate + }) if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') diff --git a/modules/ui.py b/modules/ui.py index be4a43a7..a08ffc9b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1172,7 +1172,6 @@ def create_ui(wrap_gradio_gpu_call): training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) - write_csv_every = gr.Number(label='Save an csv containing the loss to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) @@ -1251,7 +1250,6 @@ def create_ui(wrap_gradio_gpu_call): steps, create_image_every, save_embedding_every, - write_csv_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, @@ -1274,7 +1272,6 @@ def create_ui(wrap_gradio_gpu_call): steps, create_image_every, save_embedding_every, - write_csv_every, template_file, preview_from_txt2img, *txt2img_preview_params, -- cgit v1.2.3 From e21f01f64504bc651da6e85216474bbd35ee010d Mon Sep 17 00:00:00 2001 From: Rae Fu Date: Thu, 13 Oct 2022 23:00:38 -0600 Subject: add checkpoint cache option to UI for faster model switching switching time reduced from ~1500ms to ~280ms --- modules/sd_models.py | 54 +++++++++++++++++++++++++++++++--------------------- modules/shared.py | 1 + 2 files changed, 33 insertions(+), 22 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a55b4c3..f3660d8d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,4 +1,4 @@ -import glob +import collections import os.path import sys from collections import namedtuple @@ -15,6 +15,7 @@ model_path = os.path.abspath(os.path.join(models_path, model_dir)) CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) checkpoints_list = {} +checkpoints_loaded = collections.OrderedDict() try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -132,38 +133,46 @@ def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") + if checkpoint_info not in checkpoints_loaded: + print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - pl_sd = torch.load(checkpoint_file, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") + pl_sd = torch.load(checkpoint_file, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") - sd = get_state_dict_from_checkpoint(pl_sd) + sd = get_state_dict_from_checkpoint(pl_sd) + model.load_state_dict(sd, strict=False) - model.load_state_dict(sd, strict=False) + if shared.cmd_opts.opt_channelslast: + model.to(memory_format=torch.channels_last) - if shared.cmd_opts.opt_channelslast: - model.to(memory_format=torch.channels_last) + if not shared.cmd_opts.no_half: + model.half() - if not shared.cmd_opts.no_half: - model.half() + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: + vae_file = shared.cmd_opts.vae_path - if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: - vae_file = shared.cmd_opts.vae_path + if os.path.exists(vae_file): + print(f"Loading VAE weights from: {vae_file}") + vae_ckpt = torch.load(vae_file, map_location="cpu") + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} - if os.path.exists(vae_file): - print(f"Loading VAE weights from: {vae_file}") - vae_ckpt = torch.load(vae_file, map_location="cpu") - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + model.first_stage_model.load_state_dict(vae_dict) - model.first_stage_model.load_state_dict(vae_dict) + model.first_stage_model.to(devices.dtype_vae) - model.first_stage_model.to(devices.dtype_vae) + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) # LRU + else: + print(f"Loading weights [{sd_model_hash}] from cache") + checkpoints_loaded.move_to_end(checkpoint_info) + model.load_state_dict(checkpoints_loaded[checkpoint_info]) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file @@ -202,6 +211,7 @@ def reload_model_weights(sd_model, info=None): return if sd_model.sd_checkpoint_info.config != checkpoint_info.config: + checkpoints_loaded.clear() shared.sd_model = load_model() return shared.sd_model diff --git a/modules/shared.py b/modules/shared.py index 5901e605..b2090da1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -238,6 +238,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), + "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), -- cgit v1.2.3 From cd58e44051f658f2efb544203a92837f43786372 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 23:17:28 +0300 Subject: disabling history - i knew it was slow as fuck but i didn't realize it would also show galleries on launch --- modules/ui.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index a08ffc9b..6d193955 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1089,7 +1089,8 @@ def create_ui(wrap_gradio_gpu_call): "t2i":txt2img_paste_fields, "i2i":img2img_paste_fields } - images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) + + #images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): @@ -1486,7 +1487,7 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (images_history, "History", "images_history"), + #(images_history, "History", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), (settings_interface, "Settings", "settings"), -- cgit v1.2.3 From 368f4cc4c73509c1968cd9defe068d8bf4ff7c4f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 23:19:05 +0300 Subject: set firstpass w/h to 0 by default and rever to old behavior when any are 0 --- modules/processing.py | 49 ++++++++++++++++++++++++++++++------------------- modules/ui.py | 4 ++-- 2 files changed, 32 insertions(+), 21 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 100a259f..a75b9f84 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -501,17 +501,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - firstphase_width = 0 - firstphase_height = 0 - firstphase_width_truncated = 0 - firstphase_height_truncated = 0 - def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=512, firstphase_height=512, **kwargs): + def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength self.firstphase_width = firstphase_width self.firstphase_height = firstphase_height + self.truncate_x = 0 + self.truncate_y = 0 def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: @@ -520,6 +518,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 + if self.firstphase_width == 0 or self.firstphase_height == 0: + desired_pixel_count = 512 * 512 + actual_pixel_count = self.width * self.height + scale = math.sqrt(desired_pixel_count / actual_pixel_count) + self.firstphase_width = math.ceil(scale * self.width / 64) * 64 + self.firstphase_height = math.ceil(scale * self.height / 64) * 64 + firstphase_width_truncated = int(scale * self.width) + firstphase_height_truncated = int(scale * self.height) + + else: + self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" + + width_ratio = self.width / self.firstphase_width + height_ratio = self.height / self.firstphase_height + + if width_ratio > height_ratio: + firstphase_width_truncated = self.firstphase_width + firstphase_height_truncated = self.firstphase_width * self.height / self.width + else: + firstphase_width_truncated = self.firstphase_height * self.width / self.height + firstphase_height_truncated = self.firstphase_height + + self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f + self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f + + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) @@ -528,23 +552,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) return samples - self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" - x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) - truncate_x = 0 - truncate_y = 0 - width_ratio = self.width/self.firstphase_width - height_ratio = self.height/self.firstphase_height - - if width_ratio > height_ratio: - truncate_y = int((self.width - self.firstphase_width) / width_ratio / height_ratio / opt_f) - - elif width_ratio < height_ratio: - truncate_x = int((self.height - self.firstphase_height) / width_ratio / height_ratio / opt_f) - - samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2] + samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] decoded_samples = decode_first_stage(self.sd_model, samples) diff --git a/modules/ui.py b/modules/ui.py index 6d193955..a1d18be9 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -567,8 +567,8 @@ def create_ui(wrap_gradio_gpu_call): enable_hr = gr.Checkbox(label='Highres. fix', value=False) with gr.Row(visible=False) as hr_options: - firstphase_width = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass width", value=512) - firstphase_height = gr.Slider(minimum=64, maximum=1024, step=64, label="First pass height", value=512) + firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="First pass width", value=0) + firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="First pass height", value=0) denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) with gr.Row(equal_height=True): -- cgit v1.2.3 From 4d19f3b7d461fe0f63e7ccff936909b0ce0c6126 Mon Sep 17 00:00:00 2001 From: Melan Date: Fri, 14 Oct 2022 22:45:26 +0200 Subject: Raise an assertion error if no training images have been found. --- modules/textual_inversion/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 67e90afe..12e2f43b 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -81,7 +81,8 @@ class PersonalizedBase(Dataset): entry.cond = cond_model([entry.cond_text]).to(devices.cpu) self.dataset.append(entry) - + + assert len(self.dataset) > 1, "No images have been found in the dataset." self.length = len(self.dataset) * repeats self.initial_indexes = np.arange(self.length) % len(self.dataset) -- cgit v1.2.3 From 4dc426509918e90bf4557ecfd1f84031362360c0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 15 Oct 2022 00:21:48 +0300 Subject: rename firstpass w/h to discard old user settings --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index a1d18be9..c5d295ea 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -567,8 +567,8 @@ def create_ui(wrap_gradio_gpu_call): enable_hr = gr.Checkbox(label='Highres. fix', value=False) with gr.Row(visible=False) as hr_options: - firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="First pass width", value=0) - firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="First pass height", value=0) + firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0) + firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0) denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) with gr.Row(equal_height=True): -- cgit v1.2.3 From 4bbe5d62e042e78cfe1dc83492c2398a39a2455c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 15 Oct 2022 00:25:09 +0300 Subject: reformat lines in images_history.py --- modules/images_history.py | 182 +++++++++++++++++++++++++--------------------- 1 file changed, 98 insertions(+), 84 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 723f5301..f5ef44fe 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -1,5 +1,7 @@ import os import shutil + + def traverse_all_files(output_dir, image_list, curr_dir=None): curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir) try: @@ -16,10 +18,10 @@ def traverse_all_files(output_dir, image_list, curr_dir=None): elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0: image_list.append(file) else: - image_list = traverse_all_files(output_dir, image_list, file) + image_list = traverse_all_files(output_dir, image_list, file) return image_list - + def get_recent_images(dir_name, page_index, step, image_index, tabname): page_index = int(page_index) f_list = os.listdir(dir_name) @@ -27,36 +29,48 @@ def get_recent_images(dir_name, page_index, step, image_index, tabname): image_list = traverse_all_files(dir_name, image_list) image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) num = 48 if tabname != "extras" else 12 - max_page_index = len(image_list) // num + 1 + max_page_index = len(image_list) // num + 1 page_index = max_page_index if page_index == -1 else page_index + step - page_index = 1 if page_index < 1 else page_index + page_index = 1 if page_index < 1 else page_index page_index = max_page_index if page_index > max_page_index else page_index idx_frm = (page_index - 1) * num image_list = image_list[idx_frm:idx_frm + num] image_index = int(image_index) - if image_index < 0 or image_index > len(image_list) - 1: - current_file = None + if image_index < 0 or image_index > len(image_list) - 1: + current_file = None hidden = None else: - current_file = image_list[int(image_index)] + current_file = image_list[int(image_index)] hidden = os.path.join(dir_name, current_file) return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, "" + def first_page_click(dir_name, page_index, image_index, tabname): return get_recent_images(dir_name, 1, 0, image_index, tabname) + + def end_page_click(dir_name, page_index, image_index, tabname): return get_recent_images(dir_name, -1, 0, image_index, tabname) + + def prev_page_click(dir_name, page_index, image_index, tabname): return get_recent_images(dir_name, page_index, -1, image_index, tabname) -def next_page_click(dir_name, page_index, image_index, tabname): + + +def next_page_click(dir_name, page_index, image_index, tabname): return get_recent_images(dir_name, page_index, 1, image_index, tabname) -def page_index_change(dir_name, page_index, image_index, tabname): + + +def page_index_change(dir_name, page_index, image_index, tabname): return get_recent_images(dir_name, page_index, 0, image_index, tabname) + def show_image_info(num, image_path, filenames): - #print(f"select image {num}") + # print(f"select image {num}") file = filenames[int(num)] return file, num, os.path.join(image_path, file) + + def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index): if name == "": return filenames, delete_num @@ -66,14 +80,14 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima i = 0 new_file_list = [] for name in filenames: - if i >= index and i < index + delete_num: + if i >= index and i < index + delete_num: path = os.path.join(dir_name, name) - if os.path.exists(path): + if os.path.exists(path): print(f"Delete file {path}") os.remove(path) - txt_file = os.path.splitext(path)[0] + ".txt" + txt_file = os.path.splitext(path)[0] + ".txt" if os.path.exists(txt_file): - os.remove(txt_file) + os.remove(txt_file) else: print(f"Not exists file {path}") else: @@ -81,81 +95,81 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima i += 1 return new_file_list, 1 + def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): - if tabname == "txt2img": - dir_name = opts.outdir_txt2img_samples - elif tabname == "img2img": - dir_name = opts.outdir_img2img_samples - elif tabname == "extras": - dir_name = opts.outdir_extras_samples - d = dir_name.split("/") - dir_name = d[0] - for p in d[1:]: - dir_name = os.path.join(dir_name, p) - with gr.Row(): - renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page") - first_page = gr.Button('First Page') - prev_page = gr.Button('Prev Page') - page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next Page') - end_page = gr.Button('End Page') - with gr.Row(elem_id=tabname + "_images_history"): - with gr.Row(): - with gr.Column(scale=2): - history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) - with gr.Row(): - delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") - delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - with gr.Column(): - with gr.Row(): - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') - with gr.Row(): - with gr.Column(): - img_file_info = gr.Textbox(label="Generate Info", interactive=False) - img_file_name = gr.Textbox(label="File Name", interactive=False) - with gr.Row(): - # hiden items - - img_path = gr.Textbox(dir_name.rstrip("/") , visible=False) - tabname_box = gr.Textbox(tabname, visible=False) - image_index = gr.Textbox(value=-1, visible=False) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False) - filenames = gr.State() - hidden = gr.Image(type="pil", visible=False) - info1 = gr.Textbox(visible=False) - info2 = gr.Textbox(visible=False) - - - # turn pages - gallery_inputs = [img_path, page_index, image_index, tabname_box] - gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name] - - first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - #page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index]) - - #other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden]) - img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) - delete.click(delete_image,_js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num]) - hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) - - #pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) - switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') - switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') - - + if tabname == "txt2img": + dir_name = opts.outdir_txt2img_samples + elif tabname == "img2img": + dir_name = opts.outdir_img2img_samples + elif tabname == "extras": + dir_name = opts.outdir_extras_samples + d = dir_name.split("/") + dir_name = d[0] + for p in d[1:]: + dir_name = os.path.join(dir_name, p) + with gr.Row(): + renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page") + first_page = gr.Button('First Page') + prev_page = gr.Button('Prev Page') + page_index = gr.Number(value=1, label="Page Index") + next_page = gr.Button('Next Page') + end_page = gr.Button('End Page') + with gr.Row(elem_id=tabname + "_images_history"): + with gr.Row(): + with gr.Column(scale=2): + history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) + with gr.Row(): + delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") + delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") + with gr.Column(): + with gr.Row(): + pnginfo_send_to_txt2img = gr.Button('Send to txt2img') + pnginfo_send_to_img2img = gr.Button('Send to img2img') + with gr.Row(): + with gr.Column(): + img_file_info = gr.Textbox(label="Generate Info", interactive=False) + img_file_name = gr.Textbox(label="File Name", interactive=False) + with gr.Row(): + # hiden items + + img_path = gr.Textbox(dir_name.rstrip("/"), visible=False) + tabname_box = gr.Textbox(tabname, visible=False) + image_index = gr.Textbox(value=-1, visible=False) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False) + filenames = gr.State() + hidden = gr.Image(type="pil", visible=False) + info1 = gr.Textbox(visible=False) + info2 = gr.Textbox(visible=False) + + # turn pages + gallery_inputs = [img_path, page_index, image_index, tabname_box] + gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name] + + first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index]) + + # other funcitons + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden]) + img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) + delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num]) + hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) + + # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) + switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') + switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') + + def create_history_tabs(gr, opts, run_pnginfo, switch_dict): with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: with gr.Tab("txt2img history"): - with gr.Blocks(analytics_enabled=False) as images_history_txt2img: - show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict) + with gr.Blocks(analytics_enabled=False) as images_history_txt2img: + show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict) with gr.Tab("img2img history"): with gr.Blocks(analytics_enabled=False) as images_history_img2img: show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict) -- cgit v1.2.3 From acedbe67d2b8a3af99ca3b9a2f809e7a2db285d1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 15 Oct 2022 00:43:15 +0300 Subject: bring history tab back, make it behave; it's still slow but won't fuck anything up until you use it --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index c5d295ea..1bc919c7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1090,7 +1090,7 @@ def create_ui(wrap_gradio_gpu_call): "i2i":img2img_paste_fields } - #images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) + images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): @@ -1487,7 +1487,7 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - #(images_history, "History", "images_history"), + (images_history, "History", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), (settings_interface, "Settings", "settings"), -- cgit v1.2.3 From c7a86f7fe9c0b8967a87e8d709f507d2f44400d8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 15 Oct 2022 09:24:59 +0300 Subject: add option to use batch size for training --- modules/hypernetworks/hypernetwork.py | 33 +++++++++++++++++++------- modules/textual_inversion/dataset.py | 31 ++++++++++++++---------- modules/textual_inversion/textual_inversion.py | 17 +++++++------ modules/ui.py | 3 +++ 4 files changed, 54 insertions(+), 30 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 59c7ac6e..a2b3bc0a 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -182,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): return self.to_out(out) -def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def stack_conds(conds): + if len(conds) == 1: + return torch.stack(conds) + + # same as in reconstruct_multicond_batch + token_count = max([x.shape[0] for x in conds]) + for i in range(len(conds)): + if conds[i].shape[0] != token_count: + last_vector = conds[i][-1:] + last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1]) + conds[i] = torch.vstack([conds[i], last_vector_repeated]) + + return torch.stack(conds) + +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert hypernetwork_name, 'hypernetwork not selected' path = shared.hypernetworks.get(hypernetwork_name, None) @@ -211,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) @@ -235,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, entry in pbar: + for i, entries in pbar: hypernetwork.step = i + ititial_step scheduler.apply(optimizer, hypernetwork.step) @@ -246,11 +260,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, break with torch.autocast("cuda"): - cond = entry.cond.to(devices.device) - x = entry.latent.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), cond)[0] + c = stack_conds([entry.cond for entry in entries]).to(devices.device) +# c = torch.vstack([entry.cond for entry in entries]).to(devices.device) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] del x - del cond + del c losses[hypernetwork.step % losses.shape[0]] = loss.item() @@ -292,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, p.width = preview_width p.height = preview_height else: - p.prompt = entry.cond_text + p.prompt = entries[0].cond_text p.steps = 20 preview_text = p.prompt @@ -315,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,

Loss: {losses.mean():.7f}
Step: {hypernetwork.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entries[0].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 67e90afe..bd99c0cb 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -24,11 +24,12 @@ class DatasetEntry: class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): - re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): + re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None self.placeholder_token = placeholder_token + self.batch_size = batch_size self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) @@ -78,13 +79,13 @@ class PersonalizedBase(Dataset): if include_cond: entry.cond_text = self.create_text(filename_text) - entry.cond = cond_model([entry.cond_text]).to(devices.cpu) + entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) self.dataset.append(entry) - self.length = len(self.dataset) * repeats + self.length = len(self.dataset) * repeats // batch_size - self.initial_indexes = np.arange(self.length) % len(self.dataset) + self.initial_indexes = np.arange(len(self.dataset)) self.indexes = None self.shuffle() @@ -101,13 +102,19 @@ class PersonalizedBase(Dataset): return self.length def __getitem__(self, i): - if i % len(self.dataset) == 0: - self.shuffle() + res = [] - index = self.indexes[i % len(self.indexes)] - entry = self.dataset[index] + for j in range(self.batch_size): + position = i * self.batch_size + j + if position % len(self.indexes) == 0: + self.shuffle() - if entry.cond is None: - entry.cond_text = self.create_text(entry.filename_text) + index = self.indexes[position % len(self.indexes)] + entry = self.dataset[index] - return entry + if entry.cond is None: + entry.cond_text = self.create_text(entry.filename_text) + + res.append(entry) + + return res diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index da0d77a0..e754747e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -199,7 +199,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -231,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) hijack = sd_hijack.model_hijack @@ -251,7 +251,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, entry in pbar: + for i, entries in pbar: embedding.step = i + ititial_step scheduler.apply(optimizer, embedding.step) @@ -262,10 +262,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini break with torch.autocast("cuda"): - c = cond_model([entry.cond_text]) - - x = entry.latent.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] + c = cond_model([entry.cond_text for entry in entries]) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] del x losses[embedding.step % losses.shape[0]] = loss.item() @@ -307,7 +306,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini p.width = preview_width p.height = preview_height else: - p.prompt = entry.cond_text + p.prompt = entries[0].cond_text p.steps = 20 p.width = training_width p.height = training_height @@ -348,7 +347,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini

Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entries[0].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

diff --git a/modules/ui.py b/modules/ui.py index 1bc919c7..45550ea8 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1166,6 +1166,7 @@ def create_ui(wrap_gradio_gpu_call): train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") + batch_size = gr.Number(label='Batch size', value=1, precision=0) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) @@ -1244,6 +1245,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ train_embedding_name, learn_rate, + batch_size, dataset_directory, log_directory, training_width, @@ -1268,6 +1270,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ train_hypernetwork_name, learn_rate, + batch_size, dataset_directory, log_directory, steps, -- cgit v1.2.3 From db27b987a97fc8b7894a9dd34bd7641536f9c424 Mon Sep 17 00:00:00 2001 From: aoirusann Date: Sat, 15 Oct 2022 11:48:13 +0800 Subject: Add hint for `ctrl/alt enter` And duplicate implementations are removed --- modules/ui.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 45550ea8..baf4c397 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -433,7 +433,10 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=80): with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2) + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, + placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + with gr.Column(scale=1, elem_id="roll_col"): roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) paste = gr.Button(value=paste_symbol, elem_id="paste") @@ -446,7 +449,10 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=8): with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2) + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, + placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + with gr.Column(scale=1, elem_id="roll_col"): sh = gr.Button(elem_id="sh", visible=True) -- cgit v1.2.3 From cd28465bf87d911965790513c37e6881e4231523 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Sat, 15 Oct 2022 10:56:02 +0900 Subject: do not force relative paths in image history --- modules/images_history.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index f5ef44fe..09c749fe 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -104,7 +104,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): elif tabname == "extras": dir_name = opts.outdir_extras_samples d = dir_name.split("/") - dir_name = d[0] + dir_name = "/" if dir_name.startswith("/") else d[0] for p in d[1:]: dir_name = os.path.join(dir_name, p) with gr.Row(): -- cgit v1.2.3 From 0da6c1809996f0f696d4047faf4b9c9939e26daa Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Sat, 15 Oct 2022 11:22:05 +0900 Subject: use "outdir_samples" if specified --- modules/images_history.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 09c749fe..9260df8a 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -97,7 +97,9 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): - if tabname == "txt2img": + if opts.outdir_samples != "": + dir_name = opts.outdir_samples + elif tabname == "txt2img": dir_name = opts.outdir_txt2img_samples elif tabname == "img2img": dir_name = opts.outdir_img2img_samples -- cgit v1.2.3 From a13af34b902bebc5df9509228380206a01f1245b Mon Sep 17 00:00:00 2001 From: githublsx Date: Thu, 13 Oct 2022 20:05:07 -0700 Subject: Set to -1 when seed input is none --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index a75b9f84..7e2a416d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -140,7 +140,7 @@ class Processed: self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] - self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) + self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1 self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.all_prompts = all_prompts or [self.prompt] -- cgit v1.2.3 From c24df4b486a48c60f48276f7760a9acb4a13e22d Mon Sep 17 00:00:00 2001 From: CookieHCl Date: Sat, 15 Oct 2022 03:26:36 +0900 Subject: Disable compiling deepbooru model This is only necessary when you have to train, and compiling model produces warning. --- modules/deepbooru.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index f34f3788..4ad334a1 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -102,7 +102,7 @@ def get_deepbooru_tags_model(): tags = dd.project.load_tags_from_project(model_path) model = dd.project.load_model_from_project( - model_path, compile_model=True + model_path, compile_model=False ) return model, tags -- cgit v1.2.3 From e8729dd0511f8410db967d9ef192645cfef1be8a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 15 Oct 2022 12:54:23 +0300 Subject: re-apply height hacks to work with new gradio --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index baf4c397..9c7a67dd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -750,10 +750,10 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: with gr.TabItem('img2img', id='img2img'): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool) + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480) with gr.TabItem('Inpaint', id='inpaint'): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA") + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") -- cgit v1.2.3 From 5967d07d1aa4e2fef031a57b1612b1ab04a3cd78 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 15 Oct 2022 13:11:28 +0300 Subject: fix new gradio failing to preserve image params --- modules/ui.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 9c7a67dd..de5ab929 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -7,6 +7,7 @@ import mimetypes import os import random import sys +import tempfile import time import traceback import platform @@ -176,6 +177,23 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") +def save_pil_to_file(pil_image, dir=None): + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in pil_image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + + file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) + pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) + return file_obj + + +# override save to file function so that it also writes PNG info +gr.processing_utils.save_pil_to_file = save_pil_to_file + + def wrap_gradio_call(func, extra_outputs=None): def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled -- cgit v1.2.3 From f7ca63937ac83d32483285c3af09afaa356d6276 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 15 Oct 2022 13:23:12 +0300 Subject: bring back scale latent option in settings --- modules/processing.py | 8 ++++---- modules/shared.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 7e2a416d..b9a1660e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -557,11 +557,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] - decoded_samples = decode_first_stage(self.sd_model, samples) + if opts.use_scale_latent_for_hires_fix: + samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") - if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": - decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") else: + decoded_samples = decode_first_stage(self.sd_model, samples) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) batch_images = [] @@ -578,7 +578,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): decoded_samples = decoded_samples.to(shared.device) decoded_samples = 2. * decoded_samples - 1. - samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) + samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) shared.state.nextjob() diff --git a/modules/shared.py b/modules/shared.py index aa69bedf..b4141e67 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -218,6 +218,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), + "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space iamge when doing hires. fix"), })) options_templates.update(options_section(('face-restoration', "Face restoration"), { @@ -256,7 +257,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), - 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { @@ -284,6 +284,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -- cgit v1.2.3 From d3463bc59a44d62c2de8b357184c49876d84f654 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 15 Oct 2022 14:22:30 +0300 Subject: change styling for top right corner UI made save style button not die when you cancel --- modules/ui.py | 57 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 28 insertions(+), 29 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index de5ab929..cab8ab11 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -81,6 +81,8 @@ art_symbol = '\U0001f3a8' # 🎨 paste_symbol = '\u2199\ufe0f' # ↙ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +apply_style_symbol = '\U0001f4cb' # 📋 def plaintext_to_html(text): @@ -322,7 +324,7 @@ def visit(x, func, path=""): def add_style(name: str, prompt: str, negative_prompt: str): if name is None: - return [gr_show(), gr_show()] + return [gr_show() for x in range(4)] style = modules.styles.PromptStyle(name, prompt, negative_prompt) shared.prompt_styles.styles[style.name] = style @@ -447,7 +449,7 @@ def create_toprow(is_img2img): id_part = "img2img" if is_img2img else "txt2img" with gr.Row(elem_id="toprow"): - with gr.Column(scale=4): + with gr.Column(scale=6): with gr.Row(): with gr.Column(scale=80): with gr.Row(): @@ -455,27 +457,30 @@ def create_toprow(is_img2img): placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" ) - with gr.Column(scale=1, elem_id="roll_col"): - roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) - paste = gr.Button(value=paste_symbol, elem_id="paste") - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - - with gr.Column(scale=10, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) - with gr.Row(): - with gr.Column(scale=8): + with gr.Column(scale=80): with gr.Row(): negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" ) - with gr.Column(scale=1, elem_id="roll_col"): - sh = gr.Button(elem_id="sh", visible=True) + with gr.Column(scale=1, elem_id="roll_col"): + roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) + paste = gr.Button(value=paste_symbol, elem_id="paste") + save_style = gr.Button(value=save_style_symbol, elem_id="style_create") + prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + + button_interrogate = None + button_deepbooru = None + if is_img2img: + with gr.Column(scale=1, elem_id="interrogate_col"): + button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + + if cmd_opts.deepdanbooru: + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") with gr.Column(scale=1): with gr.Row(): @@ -495,20 +500,14 @@ def create_toprow(is_img2img): outputs=[], ) - with gr.Row(scale=1): - if is_img2img: - interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - if cmd_opts.deepdanbooru: - deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - else: - deepbooru = None - else: - interrogate = None - deepbooru = None - prompt_style_apply = gr.Button('Apply style', elem_id="style_apply") - save_style = gr.Button('Create style', elem_id="style_create") + with gr.Row(): + with gr.Column(scale=1, elem_id="style_pos_col"): + prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + + with gr.Column(scale=1, elem_id="style_neg_col"): + prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button def setup_progressbar(progressbar, preview, id_part, textinfo=None): -- cgit v1.2.3 From 20a1f68c752f8e37492ea00911c97bfc572a6e67 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 15 Oct 2022 15:44:46 +0300 Subject: fix gadio issue with sending files between tabs --- modules/ui.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index cab8ab11..c9b53247 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -91,6 +91,14 @@ def plaintext_to_html(text): def image_from_url_text(filedata): + if type(filedata) == dict and filedata["is_file"]: + filename = filedata["name"] + tempdir = os.path.normpath(tempfile.gettempdir()) + normfn = os.path.normpath(filename) + assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory' + + return Image.open(filename) + if type(filedata) == list: if len(filedata) == 0: return None -- cgit v1.2.3 From 97f0727489ddd3d7ca264c54ed0f63b6847502e2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 15 Oct 2022 15:47:02 +0300 Subject: add First pass size always regardless of whether it was auto chosen or specified --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index b9a1660e..941ae089 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -528,7 +528,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): firstphase_height_truncated = int(scale * self.height) else: - self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" width_ratio = self.width / self.firstphase_width height_ratio = self.height / self.firstphase_height @@ -540,6 +539,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): firstphase_width_truncated = self.firstphase_height * self.width / self.height firstphase_height_truncated = self.firstphase_height + self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f -- cgit v1.2.3 From eef3bc649069d6caaef1274f132de28e528bfa7d Mon Sep 17 00:00:00 2001 From: NO_ob <15161159+NO-ob@users.noreply.github.com> Date: Sat, 15 Oct 2022 11:43:30 +0100 Subject: typo --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index b4141e67..fa30bbb0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -218,7 +218,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), - "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space iamge when doing hires. fix"), + "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"), })) options_templates.update(options_section(('face-restoration', "Face restoration"), { -- cgit v1.2.3 From 37d7ffb415cd8c69b3c0bb5f61844dde0b169f78 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 15:59:37 +0200 Subject: fix to tokens lenght, addend embs generator, add new features to edit the embedding before the generation using text --- modules/aesthetic_clip.py | 78 ++++++++++++++++++++++++ modules/processing.py | 148 +++++++++++++++++++++++++++++++--------------- modules/sd_hijack.py | 111 ++++++++++++++++++++++------------ modules/shared.py | 4 ++ modules/txt2img.py | 10 +++- modules/ui.py | 47 ++++++++++++--- 6 files changed, 302 insertions(+), 96 deletions(-) create mode 100644 modules/aesthetic_clip.py (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py new file mode 100644 index 00000000..f15cfd47 --- /dev/null +++ b/modules/aesthetic_clip.py @@ -0,0 +1,78 @@ +import itertools +import os +from pathlib import Path +import html +import gc + +import gradio as gr +import torch +from PIL import Image +from modules import shared +from modules.shared import device, aesthetic_embeddings +from transformers import CLIPModel, CLIPProcessor + +from tqdm.auto import tqdm + + +def get_all_images_in_folder(folder): + return [os.path.join(folder, f) for f in os.listdir(folder) if + os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)] + + +def check_is_valid_image_file(filename): + return filename.lower().endswith(('.png', '.jpg', '.jpeg')) + + +def batched(dataset, total, n=1): + for ndx in range(0, total, n): + yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] + + +def iter_to_batched(iterable, n=1): + it = iter(iterable) + while True: + chunk = tuple(itertools.islice(it, n)) + if not chunk: + return + yield chunk + + +def generate_imgs_embd(name, folder, batch_size): + # clipModel = CLIPModel.from_pretrained( + # shared.sd_model.cond_stage_model.clipModel.name_or_path + # ) + model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device) + processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path) + + with torch.no_grad(): + embs = [] + for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size), + desc=f"Generating embeddings for {name}"): + if shared.state.interrupted: + break + inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device) + outputs = model.get_image_features(**inputs).cpu() + embs.append(torch.clone(outputs)) + inputs.to("cpu") + del inputs, outputs + + embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True) + + # The generated embedding will be located here + path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") + torch.save(embs, path) + + model = model.cpu() + del model + del processor + del embs + gc.collect() + torch.cuda.empty_cache() + res = f""" + Done generating embedding for {name}! + Hypernetwork saved to {html.escape(path)} + """ + shared.update_aesthetic_embeddings() + return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", + value=sorted(aesthetic_embeddings.keys())[0] if len( + aesthetic_embeddings) > 0 else None), res, "" diff --git a/modules/processing.py b/modules/processing.py index 9a033759..ab68d63a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -20,7 +20,6 @@ import modules.images as images import modules.styles import logging - # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 opt_f = 8 @@ -52,8 +51,13 @@ def get_correct_sampler(p): elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): return sd_samplers.samplers_for_img2img + class StableDiffusionProcessing: - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, + subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, + sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, + restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, + extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids @@ -104,7 +108,8 @@ class StableDiffusionProcessing: class Processed: - def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): + def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, + all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): self.images = images_list self.prompt = p.prompt self.negative_prompt = p.negative_prompt @@ -141,7 +146,8 @@ class Processed: self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) - self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 + self.subseed = int( + self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.all_prompts = all_prompts or [self.prompt] self.all_seeds = all_seeds or [self.seed] @@ -181,39 +187,43 @@ class Processed: return json.dumps(obj) - def infotext(self, p: StableDiffusionProcessing, index): - return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size) + def infotext(self, p: StableDiffusionProcessing, index): + return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], + position_in_batch=index % self.batch_size, iteration=index // self.batch_size) # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 def slerp(val, low, high): - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) - dot = (low_norm*high_norm).sum(1) + low_norm = low / torch.norm(low, dim=1, keepdim=True) + high_norm = high / torch.norm(high, dim=1, keepdim=True) + dot = (low_norm * high_norm).sum(1) if dot.mean() > 0.9995: return low * val + high * (1 - val) omega = torch.acos(dot) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high return res -def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None): +def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, + p=None): xs = [] # if we have multiple seeds, this means we are working with batch size>1; this then # enables the generation of additional tensors with noise that the sampler will use during its processing. # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # produce the same images as with two batches [100], [101]. - if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): + if p is not None and p.sampler is not None and ( + len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] else: sampler_noises = None for i, seed in enumerate(seeds): - noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8) + noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else ( + shape[0], seed_resize_from_h // 8, seed_resize_from_w // 8) subnoise = None if subseeds is not None: @@ -241,7 +251,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see dx = max(-dx, 0) dy = max(-dy, 0) - x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w] + x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w] noise = x if sampler_noises is not None: @@ -293,14 +303,20 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", - "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), - "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')), + "Model hash": getattr(p, 'sd_model_hash', + None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), + "Model": ( + None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace( + ',', '').replace(':', '')), + "Hypernet": ( + None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace( + ':', '')), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), - "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), + "Seed resize from": ( + None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, @@ -309,7 +325,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params.update(p.extra_generation_params) - generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join( + [k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" @@ -317,7 +334,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, - aesthetic_imgs=None,aesthetic_slerp=False) -> Processed: + aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" aesthetic_lr = float(aesthetic_lr) @@ -385,7 +404,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh for n in range(p.n_iter): if state.skipped: state.skipped = False - + if state.interrupted: break @@ -396,16 +415,19 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh if (len(prompts) == 0): break - #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) - #c = p.sd_model.get_learned_conditioning(prompts) + # uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) + # c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): - shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0) + shared.sd_model.cond_stage_model.set_aesthetic_params() uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight, - aesthetic_steps, aesthetic_imgs,aesthetic_slerp) + aesthetic_steps, aesthetic_imgs, + aesthetic_slerp, aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: @@ -413,13 +435,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh comments[comment] = 1 if p.n_iter > 1: - shared.state.job = f"Batch {n+1} out of {p.n_iter}" + shared.state.job = f"Batch {n + 1} out of {p.n_iter}" with devices.autocast(): - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, + subseed_strength=p.subseed_strength) if state.interrupted or state.skipped: - # if we are interrupted, sample returns just noise # use the image collected previously in sampler loop samples_ddim = shared.state.current_latent @@ -445,7 +467,9 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh if p.restore_faces: if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: - images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") + images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], + opts.samples_format, info=infotext(n, i), p=p, + suffix="-before-face-restoration") devices.torch_gc() @@ -456,7 +480,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: - images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction") + images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, + info=infotext(n, i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) if p.overlay_images is not None and i < len(p.overlay_images): @@ -474,7 +499,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh image = image.convert('RGB') if opts.samples_save and not p.do_not_save_samples: - images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) + images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, + info=infotext(n, i), p=p) text = infotext(n, i) infotexts.append(text) @@ -482,7 +508,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh image.info["parameters"] = text output_images.append(image) - del x_samples_ddim + del x_samples_ddim devices.torch_gc() @@ -504,10 +530,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh index_of_first_image = 1 if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, + info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) devices.torch_gc() - return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) + return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), + subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, + index_of_first_image=index_of_first_image, infotexts=infotexts) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): @@ -543,25 +572,34 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, + subseeds=subseeds, subseed_strength=self.subseed_strength, + seed_resize_from_h=self.seed_resize_from_h, + seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) return samples - x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, + subseeds=subseeds, subseed_strength=self.subseed_strength, + seed_resize_from_h=self.seed_resize_from_h, + seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f - samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2] + samples = samples[:, :, truncate_y // 2:samples.shape[2] - truncate_y // 2, + truncate_x // 2:samples.shape[3] - truncate_x // 2] if self.scale_latent: - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), + mode="bilinear") else: decoded_samples = decode_first_stage(self.sd_model, samples) if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": - decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") + decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), + mode="bilinear") else: lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) @@ -585,13 +623,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) - noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, + subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, + seed_resize_from_w=self.seed_resize_from_w, p=self) # GC now before running the next img2img to prevent running out of memory x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, + steps=self.steps) return samples @@ -599,7 +640,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): sampler = None - def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs): + def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, + inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, + **kwargs): super().__init__(**kwargs) self.init_images = init_images @@ -607,7 +650,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.denoising_strength: float = denoising_strength self.init_latent = None self.image_mask = mask - #self.image_unblurred_mask = None + # self.image_unblurred_mask = None self.latent_mask = None self.mask_for_overlay = None self.mask_blur = mask_blur @@ -619,7 +662,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.nmask = None def init(self, all_prompts, all_seeds, all_subseeds): - self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, + self.sd_model) crop_region = None if self.image_mask is not None: @@ -628,7 +672,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.inpainting_mask_invert: self.image_mask = ImageOps.invert(self.image_mask) - #self.image_unblurred_mask = self.image_mask + # self.image_unblurred_mask = self.image_mask if self.mask_blur > 0: self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)) @@ -642,7 +686,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask = mask.crop(crop_region) self.image_mask = images.resize_image(2, mask, self.width, self.height) - self.paste_to = (x1, y1, x2-x1, y2-y1) + self.paste_to = (x1, y1, x2 - x1, y2 - y1) else: self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height) np_mask = np.array(self.image_mask) @@ -665,7 +709,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.image_mask is not None: image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + image_masked.paste(image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) self.overlay_images.append(image_masked.convert('RGBA')) @@ -714,12 +759,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): # this needs to be fixed to be done in sample() using actual seeds for batches if self.inpainting_fill == 2: - self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask + self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], + all_seeds[ + 0:self.init_latent.shape[ + 0]]) * self.nmask elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, + subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, + seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6d5196fe..192883b2 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -14,7 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model -from transformers import CLIPVisionModel, CLIPModel +from tqdm import trange +from transformers import CLIPVisionModel, CLIPModel, CLIPTokenizer import torch.optim as optim import copy @@ -22,21 +23,25 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward + def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and ( + 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): + elif not cmd_opts.disable_opt_split_attention and ( + cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): if not invokeAI_mps_available and shared.device.type == 'mps': - print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") + print( + "The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 else: @@ -112,14 +117,16 @@ class StableDiffusionModelHijack: _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) + def slerp(low, high, val): - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm*high_norm).sum(1)) + low_norm = low / torch.norm(low, dim=1, keepdim=True) + high_norm = high / torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm * high_norm).sum(1)) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high return res + class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() @@ -128,6 +135,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.wrapped.transformer.name_or_path ) del self.clipModel.vision_model + self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path) self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval() @@ -139,7 +147,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] - tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if + '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 for c in text: @@ -155,8 +164,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None, - aesthetic_slerp=True): + def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, + aesthetic_slerp=True, aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False): + self.aesthetic_imgs_text = aesthetic_imgs_text + self.aesthetic_slerp_angle = aesthetic_slerp_angle + self.aesthetic_text_negative = aesthetic_text_negative self.slerp = aesthetic_slerp self.aesthetic_lr = aesthetic_lr self.aesthetic_weight = aesthetic_weight @@ -180,7 +194,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: parsed = [[line, 1.0]] - tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"] + tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)[ + "input_ids"] fixes = [] remade_tokens = [] @@ -196,18 +211,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if token == self.comma_token: last_comma = len(remade_tokens) - elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: + elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), + 1) % 75 == 0 and last_comma != -1 and len( + remade_tokens) - last_comma <= opts.comma_padding_backtrack: last_comma += 1 reloc_tokens = remade_tokens[last_comma:] reloc_mults = multipliers[last_comma:] remade_tokens = remade_tokens[:last_comma] length = len(remade_tokens) - + rem = int(math.ceil(length / 75)) * 75 - length remade_tokens += [id_end] * rem + reloc_tokens multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults - + if embedding is None: remade_tokens.append(token) multipliers.append(weight) @@ -248,7 +265,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if line in cache: remade_tokens, fixes, multipliers = cache[line] else: - remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, + hijack_comments) token_count = max(current_token_count, token_count) cache[line] = (remade_tokens, fixes, multipliers) @@ -259,7 +277,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - def process_text_old(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id @@ -289,7 +306,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, + i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: @@ -312,11 +330,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ovf = remade_tokens[maxlen - 2:] overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + hijack_comments.append( + f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] cache[tuple_tokens] = (remade_tokens, fixes, multipliers) multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) @@ -326,23 +345,26 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): hijack_fixes.append(fixes) batch_multipliers.append(multipliers) return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - + def forward(self, text): use_old = opts.use_old_emphasis_implementation if use_old: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old( + text) else: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text( + text) self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: - self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - + self.hijack.comments.append( + "Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + if use_old: self.hijack.fixes = hijack_fixes return self.process_tokens(remade_batch_tokens, batch_multipliers) - + z = None i = 0 while max(map(len, remade_batch_tokens)) != 0: @@ -356,7 +378,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if fix[0] == i: fixes.append(fix[1]) self.hijack.fixes.append(fixes) - + tokens = [] multipliers = [] for j in range(len(remade_batch_tokens)): @@ -378,19 +400,30 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens).to(device) + + model = copy.deepcopy(self.clipModel).to(device) + model.requires_grad_(True) + if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: + text_embs_2 = model.get_text_features( + **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) + if self.aesthetic_text_negative: + text_embs_2 = self.image_embs - text_embs_2 + text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) + img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) + else: + img_embs = self.image_embs + with torch.enable_grad(): - model = copy.deepcopy(self.clipModel).to(device) - model.requires_grad_(True) # We optimize the model to maximize the similarity optimizer = optim.Adam( model.text_model.parameters(), lr=self.aesthetic_lr ) - for i in range(self.aesthetic_steps): + for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"): text_embs = model.get_text_features(input_ids=tokens) text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) - sim = text_embs @ self.image_embs.T + sim = text_embs @ img_embs.T loss = -sim optimizer.zero_grad() loss.mean().backward() @@ -405,6 +438,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): model.cpu() del model + zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1) if self.slerp: z = slerp(z, zn, self.aesthetic_weight) else: @@ -413,15 +447,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 - + return z - - + def process_tokens(self, remade_batch_tokens, batch_multipliers): if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] + remade_batch_tokens = [ + [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in + remade_batch_tokens] batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] - + tokens = torch.asarray(remade_batch_tokens).to(device) outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) @@ -461,8 +496,8 @@ class EmbeddingsWithFixes(torch.nn.Module): for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: emb = embedding.vec - emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) - tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) + emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) + tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) vecs.append(tensor) diff --git a/modules/shared.py b/modules/shared.py index cf13a10d..7cd608ca 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -95,6 +95,10 @@ loaded_hypernetwork = None aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} +def update_aesthetic_embeddings(): + global aesthetic_embeddings + aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in + os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} def reload_hypernetworks(): global hypernetworks diff --git a/modules/txt2img.py b/modules/txt2img.py index 78342024..eedcdfe0 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -13,7 +13,11 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, - aesthetic_slerp=False, *args): + aesthetic_slerp=False, + aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False, + *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -47,7 +51,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is None: - processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp) + processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative) shared.total_tqdm.clear() diff --git a/modules/ui.py b/modules/ui.py index d961d126..e98e2113 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -41,6 +41,7 @@ from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui +import modules.aesthetic_clip # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -449,7 +450,7 @@ def create_toprow(is_img2img): with gr.Row(): negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2) with gr.Column(scale=1, elem_id="roll_col"): - sh = gr.Button(elem_id="sh", visible=True) + sh = gr.Button(elem_id="sh", visible=True) with gr.Column(scale=1, elem_id="style_neg_col"): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) @@ -536,9 +537,13 @@ def create_ui(wrap_gradio_gpu_call): height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) with gr.Group(): - aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50) + aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001") + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) + aesthetic_steps = gr.Slider(minimum=0, maximum=256, step=1, label="Aesthetic steps", value=5) + with gr.Row(): + aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") + aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) + aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) @@ -617,7 +622,10 @@ def create_ui(wrap_gradio_gpu_call): aesthetic_weight, aesthetic_steps, aesthetic_imgs, - aesthetic_slerp + aesthetic_slerp, + aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative ] + custom_inputs, outputs=[ txt2img_gallery, @@ -721,7 +729,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) - inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) + inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=1024, step=4, value=32) with gr.TabItem('Batch img2img', id='batch'): hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' @@ -1071,6 +1079,17 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create embedding", variant='primary') + with gr.Tab(label="Create images embedding"): + new_embedding_name_ae = gr.Textbox(label="Name") + process_src_ae = gr.Textbox(label='Source directory') + batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256) + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding_ae = gr.Button(value="Create images embedding", variant='primary') + with gr.Tab(label="Create hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) @@ -1139,7 +1158,7 @@ def create_ui(wrap_gradio_gpu_call): fn=modules.textual_inversion.ui.create_embedding, inputs=[ new_embedding_name, - initialization_text, + process_src, nvpt, ], outputs=[ @@ -1149,6 +1168,20 @@ def create_ui(wrap_gradio_gpu_call): ] ) + create_embedding_ae.click( + fn=modules.aesthetic_clip.generate_imgs_embd, + inputs=[ + new_embedding_name_ae, + process_src_ae, + batch_ae + ], + outputs=[ + aesthetic_imgs, + ti_output, + ti_outcome, + ] + ) + create_hypernetwork.click( fn=modules.hypernetworks.ui.create_hypernetwork, inputs=[ -- cgit v1.2.3 From 5fd638f14d75a71a37157ded5d33c716ab9eb8ca Mon Sep 17 00:00:00 2001 From: ruocaled Date: Sat, 15 Oct 2022 02:00:46 -0700 Subject: fix download section layout --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index c9b53247..3206113e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -619,7 +619,7 @@ def create_ui(wrap_gradio_gpu_call): txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4) - with gr.Group(): + with gr.Column(): with gr.Row(): save = gr.Button('Save') send_to_img2img = gr.Button('Send to img2img') @@ -834,7 +834,7 @@ def create_ui(wrap_gradio_gpu_call): img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4) - with gr.Group(): + with gr.Column(): with gr.Row(): save = gr.Button('Save') img2img_send_to_img2img = gr.Button('Send to img2img') -- cgit v1.2.3 From 703e6d9e4e161d36b9328eefb5200e1c44fb4afd Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sat, 15 Oct 2022 21:47:08 +0900 Subject: check NaN for hypernetwork tuning --- modules/hypernetworks/hypernetwork.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index a2b3bc0a..4905710e 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -272,15 +272,17 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log optimizer.zero_grad() loss.backward() optimizer.step() - - pbar.set_description(f"loss: {losses.mean():.7f}") + mean_loss = losses.mean() + if torch.isnan(mean_loss): + raise RuntimeError("Loss diverged.") + pbar.set_description(f"loss: {mean_loss:.7f}") if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') hypernetwork.save(last_saved_file) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { - "loss": f"{losses.mean():.7f}", + "loss": f"{mean_loss:.7f}", "learn_rate": scheduler.learn_rate }) @@ -328,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.state.textinfo = f"""

-Loss: {losses.mean():.7f}
+Loss: {mean_loss:.7f}
Step: {hypernetwork.step}
Last prompt: {html.escape(entries[0].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
-- cgit v1.2.3 From 9e846083b702a498fdb60accd72f075fa26701d9 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 14 Oct 2022 14:50:25 +0100 Subject: add vector size to embed text --- modules/textual_inversion/textual_inversion.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e754747e..6f549d62 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -327,10 +327,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc info.add_text("sd-ti-embedding", embedding_to_b64(data)) title = "<{}>".format(data.get('name', '???')) + + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception as e: + vectorSize = '?' + checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '{}'.format(embedding.step) + footer_right = 'v{} {}s'.format(vectorSize, embedding.step) captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) captioned_image = insert_image_data_embed(captioned_image, data) -- cgit v1.2.3 From 939f16529a72fe48c2ce3ef31bdaba785925a33c Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 14 Oct 2022 14:55:05 +0100 Subject: only save 1 image per embedding --- modules/textual_inversion/textual_inversion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6f549d62..1d697c90 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -242,6 +242,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc last_saved_file = "" last_saved_image = "" + embedding_yet_to_be_embedded = False ititial_step = embedding.step or 0 if ititial_step > steps: @@ -281,6 +282,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') embedding.save(last_saved_file) + embedding_yet_to_be_embedded = True write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { "loss": f"{losses.mean():.7f}", @@ -318,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc shared.state.current_image = image - if save_image_with_stored_embedding and os.path.exists(last_saved_file): + if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png') @@ -342,6 +344,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc captioned_image = insert_image_data_embed(captioned_image, data) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False image.save(last_saved_image) -- cgit v1.2.3 From 9a1dcd78edbf9caf68b9e6286d7b5ca81500e243 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 14 Oct 2022 18:14:02 +0100 Subject: add webp for embed load --- modules/textual_inversion/textual_inversion.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 1d697c90..c07bffc3 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -96,6 +96,10 @@ class EmbeddingDatabase: else: data = extract_image_data_embed(embed_image) name = data.get('name', name) + elif filename.upper().endswith('.WEBP'): + embed_image = Image.open(path) + data = extract_image_data_embed(embed_image) + name = data.get('name', name) else: data = torch.load(path, map_location="cpu") -- cgit v1.2.3 From ddf6899df0cf87d4da77cb2ce223061f4a5edf18 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 14 Oct 2022 18:23:20 +0100 Subject: generalise to popular lossless formats --- modules/textual_inversion/textual_inversion.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c07bffc3..b99df3b1 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -88,18 +88,14 @@ class EmbeddingDatabase: data = [] - if filename.upper().endswith('.PNG'): + if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']: embed_image = Image.open(path) - if 'sd-ti-embedding' in embed_image.text: + if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: data = embedding_from_b64(embed_image.text['sd-ti-embedding']) name = data.get('name', name) else: data = extract_image_data_embed(embed_image) name = data.get('name', name) - elif filename.upper().endswith('.WEBP'): - embed_image = Image.open(path) - data = extract_image_data_embed(embed_image) - name = data.get('name', name) else: data = torch.load(path, map_location="cpu") -- cgit v1.2.3 From b6e3b96dab94a00f51725f9cc977eebc6b4072ab Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 15 Oct 2022 15:17:21 +0100 Subject: Change vector size footer label --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b99df3b1..2ed345b1 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -338,7 +338,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = 'v{} {}s'.format(vectorSize, embedding.step) + footer_right = '{}v {}s'.format(vectorSize, embedding.step) captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) captioned_image = insert_image_data_embed(captioned_image, data) -- cgit v1.2.3 From 4387e4fe6479c08f7bc7e42924c3a1093e3a1872 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 18:39:29 +0200 Subject: Update modules/ui.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Víctor Gallego --- modules/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index d0696101..5bb961b2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -599,7 +599,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001") aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) - aesthetic_steps = gr.Slider(minimum=0, maximum=256, step=1, label="Aesthetic steps", value=5) + aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) + with gr.Row(): aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) -- cgit v1.2.3 From 9b7705e0573bddde26df4575c71f994d73a4d519 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 18:40:34 +0200 Subject: Update modules/aesthetic_clip.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Víctor Gallego --- modules/aesthetic_clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index f15cfd47..bcf2b073 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -70,7 +70,7 @@ def generate_imgs_embd(name, folder, batch_size): torch.cuda.empty_cache() res = f""" Done generating embedding for {name}! - Hypernetwork saved to {html.escape(path)} + Aesthetic embedding saved to {html.escape(path)} """ shared.update_aesthetic_embeddings() return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", -- cgit v1.2.3 From 0d4f5db235357aeb4c7a8738179ba33aaf5a6b75 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 18:40:58 +0200 Subject: Update modules/ui.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Víctor Gallego --- modules/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 5bb961b2..25eba548 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -597,7 +597,8 @@ def create_ui(wrap_gradio_gpu_call): height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) with gr.Group(): - aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001") + aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) -- cgit v1.2.3 From ad9bc604a8fadcfebe72be37f66cec51e7e87fb5 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 18:41:18 +0200 Subject: Update modules/ui.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Víctor Gallego --- modules/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 25eba548..3b28b69c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -607,7 +607,8 @@ def create_ui(wrap_gradio_gpu_call): aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) - aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) + aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Aesthetic imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) + aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) with gr.Row(): -- cgit v1.2.3 From 3f5c3b981e46c16bb10948d012575b25170efb3b Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 18:41:46 +0200 Subject: Update modules/ui.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Víctor Gallego --- modules/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 3b28b69c..1f6fcdc9 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1190,7 +1190,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create embedding", variant='primary') - with gr.Tab(label="Create images embedding"): + with gr.Tab(label="Create aesthetic images embedding"): + new_embedding_name_ae = gr.Textbox(label="Name") process_src_ae = gr.Textbox(label='Source directory') batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256) -- cgit v1.2.3 From 74a9ee70020ffa2746c82300c533de3f7e523f22 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 15 Oct 2022 17:25:35 +0300 Subject: fix saving images compatibility with gradio update --- modules/ui.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 3206113e..b867d40f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -154,10 +154,7 @@ def save_files(js_data, images, do_make_zip, index): writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) for image_index, filedata in enumerate(images, start_index): - if filedata.startswith("data:image/png;base64,"): - filedata = filedata[len("data:image/png;base64,"):] - - image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8')))) + image = image_from_url_text(filedata) is_grid = image_index < p.index_of_first_image i = 0 if is_grid else (image_index - p.index_of_first_image) -- cgit v1.2.3 From 529afbf4d70165a0dfd19eb9c2ec22416b794a1d Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 15 Oct 2022 19:19:54 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c81722a0..984b35c4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -24,7 +24,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward -- cgit v1.2.3 From 9a33292ce41b01252cdb8ab6214a11d274e32fa0 Mon Sep 17 00:00:00 2001 From: zhengxiaoyao0716 <1499383852@qq.com> Date: Sat, 15 Oct 2022 01:04:47 +0800 Subject: reload javascript files when custom script bodies --- modules/ui.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index b867d40f..90b8646b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,7 +12,7 @@ import time import traceback import platform import subprocess as sp -from functools import reduce +from functools import partial, reduce import numpy as np import torch @@ -1491,6 +1491,7 @@ Requested path was: {f} def reload_scripts(): modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page reload_script_bodies.click( fn=reload_scripts, @@ -1738,22 +1739,25 @@ Requested path was: {f} return demo -with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' +def load_javascript(raw_response): + with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: + javascript = f'' -jsdir = os.path.join(script_path, "javascript") -for filename in sorted(os.listdir(jsdir)): - with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: - javascript += f"\n" + jsdir = os.path.join(script_path, "javascript") + for filename in sorted(os.listdir(jsdir)): + with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: + javascript += f"\n" - -if 'gradio_routes_templates_response' not in globals(): def template_response(*args, **kwargs): - res = gradio_routes_templates_response(*args, **kwargs) - res.body = res.body.replace(b'', f'{javascript}'.encode("utf8")) + res = raw_response(*args, **kwargs) + res.body = res.body.replace( + b'', f'{javascript}'.encode("utf8")) res.init_headers() return res - gradio_routes_templates_response = gradio.routes.templates.TemplateResponse gradio.routes.templates.TemplateResponse = template_response + +reload_javascript = partial(load_javascript, + gradio.routes.templates.TemplateResponse) +reload_javascript() -- cgit v1.2.3 From 3d21684ee30ca5734126b8d08c05b3a0f513fe75 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 00:01:00 +0200 Subject: Add support to other img format, fixed dropbox update --- modules/aesthetic_clip.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index bcf2b073..68264284 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -8,7 +8,7 @@ import gradio as gr import torch from PIL import Image from modules import shared -from modules.shared import device, aesthetic_embeddings +from modules.shared import device from transformers import CLIPModel, CLIPProcessor from tqdm.auto import tqdm @@ -20,7 +20,7 @@ def get_all_images_in_folder(folder): def check_is_valid_image_file(filename): - return filename.lower().endswith(('.png', '.jpg', '.jpeg')) + return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp")) def batched(dataset, total, n=1): @@ -73,6 +73,6 @@ def generate_imgs_embd(name, folder, batch_size): Aesthetic embedding saved to {html.escape(path)} """ shared.update_aesthetic_embeddings() - return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", - value=sorted(aesthetic_embeddings.keys())[0] if len( - aesthetic_embeddings) > 0 else None), res, "" + return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", + value=sorted(shared.aesthetic_embeddings.keys())[0] if len( + shared.aesthetic_embeddings) > 0 else None), res, "" -- cgit v1.2.3 From 9325c85f780c569d1823e422eaf51b2e497e0d3e Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 00:23:47 +0200 Subject: fixed dropbox update --- modules/sd_hijack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 192883b2..491312b4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -9,7 +9,7 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared -from modules.shared import opts, device, cmd_opts, aesthetic_embeddings +from modules.shared import opts, device, cmd_opts from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention @@ -182,7 +182,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): image_embs_name = None if image_embs_name is not None and self.image_embs_name != image_embs_name: self.image_embs_name = image_embs_name - self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device) + self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) self.image_embs.requires_grad_(False) -- cgit v1.2.3 From 763b893f319cee280b86e63025eb55e7c16b02e7 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sun, 16 Oct 2022 10:03:09 +0800 Subject: images history sorting files by date --- modules/images_history.py | 261 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 196 insertions(+), 65 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index f5ef44fe..533cf51b 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -1,33 +1,74 @@ import os import shutil +import time +import hashlib +import gradio +show_max_dates_num = 3 +system_bak_path = "webui_log_and_bak" +def is_valid_date(date): + try: + time.strptime(date, "%Y%m%d") + return True + except: + return False +def reduplicative_file_move(src, dst): + def same_name_file(basename, path): + name, ext = os.path.splitext(basename) + f_list = os.listdir(path) + max_num = 0 + for f in f_list: + if len(f) <= len(basename): + continue + f_ext = f[-len(ext):] if len(ext) > 0 else "" + if f[:len(name)] == name and f_ext == ext: + if f[len(name)] == "(" and f[-len(ext)-1] == ")": + number = f[len(name)+1:-len(ext)-1] + if number.isdigit(): + if int(number) > max_num: + max_num = int(number) + return f"{name}({max_num + 1}){ext}" + name = os.path.basename(src) + save_name = os.path.join(dst, name) + if not os.path.exists(save_name): + shutil.move(src, dst) + else: + name = same_name_file(name, dst) + shutil.move(src, os.path.join(dst, name)) -def traverse_all_files(output_dir, image_list, curr_dir=None): - curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir) +def traverse_all_files(curr_path, image_list, all_type=False): try: f_list = os.listdir(curr_path) except: - if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt": - image_list.append(curr_dir) + if all_type or curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt": + image_list.append(curr_path) return image_list for file in f_list: - file = file if curr_dir is None else os.path.join(curr_dir, file) - file_path = os.path.join(curr_path, file) - if file[-4:] == ".txt": + file = os.path.join(curr_path, file) + if (not all_type) and file[-4:] == ".txt": pass - elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0: + elif os.path.isfile(file) and file[-10:].rfind(".") > 0: image_list.append(file) else: - image_list = traverse_all_files(output_dir, image_list, file) + image_list = traverse_all_files(file, image_list) return image_list - -def get_recent_images(dir_name, page_index, step, image_index, tabname): - page_index = int(page_index) - f_list = os.listdir(dir_name) +def get_recent_images(dir_name, page_index, step, image_index, tabname, date_from, date_to): + #print(f"turn_page {page_index}",date_from) + if date_from is None or date_from == "": + return None, 1, None, "" image_list = [] - image_list = traverse_all_files(dir_name, image_list) - image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) + date_list = auto_sorting(dir_name) + page_index = int(page_index) + today = time.strftime("%Y%m%d",time.localtime(time.time())) + for date in date_list: + if date >= date_from and date <= date_to: + path = os.path.join(dir_name, date) + if date == today and not os.path.exists(path): + continue + image_list = traverse_all_files(path, image_list) + + image_list = sorted(image_list, key=lambda file: -os.path.getctime(file)) num = 48 if tabname != "extras" else 12 max_page_index = len(image_list) // num + 1 page_index = max_page_index if page_index == -1 else page_index + step @@ -38,40 +79,101 @@ def get_recent_images(dir_name, page_index, step, image_index, tabname): image_index = int(image_index) if image_index < 0 or image_index > len(image_list) - 1: current_file = None - hidden = None else: - current_file = image_list[int(image_index)] - hidden = os.path.join(dir_name, current_file) - return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, "" + current_file = image_list[image_index] + return image_list, page_index, image_list, "" +def auto_sorting(dir_name): + #print(f"auto sorting") + bak_path = os.path.join(dir_name, system_bak_path) + if not os.path.exists(bak_path): + os.mkdir(bak_path) + log_file = None + files_list = [] + f_list = os.listdir(dir_name) + for file in f_list: + if file == system_bak_path: + continue + file_path = os.path.join(dir_name, file) + if not is_valid_date(file): + if file[-10:].rfind(".") > 0: + files_list.append(file_path) + else: + files_list = traverse_all_files(file_path, files_list, all_type=True) + + for file in files_list: + date_str = time.strftime("%Y%m%d",time.localtime(os.path.getctime(file))) + file_path = os.path.dirname(file) + hash_path = hashlib.md5(file_path.encode()).hexdigest() + path = os.path.join(dir_name, date_str, hash_path) + if not os.path.exists(path): + os.makedirs(path) + if log_file is None: + log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a") + log_file.write(f"{hash_path},{file_path}\n") + reduplicative_file_move(file, path) + + date_list = [] + f_list = os.listdir(dir_name) + for f in f_list: + if is_valid_date(f): + date_list.append(f) + elif f == system_bak_path: + continue + else: + reduplicative_file_move(os.path.join(dir_name, f), bak_path) + + today = time.strftime("%Y%m%d",time.localtime(time.time())) + if today not in date_list: + date_list.append(today) + return sorted(date_list) -def first_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, 1, 0, image_index, tabname) -def end_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, -1, 0, image_index, tabname) +def archive_images(dir_name): + date_list = auto_sorting(dir_name) + date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0] + return ( + gradio.update(visible=False), + gradio.update(visible=True), + gradio.Dropdown.update(choices=date_list, value=date_list[-1]), + gradio.Dropdown.update(choices=date_list, value=date_from) + ) +def date_to_change(dir_name, page_index, image_index, tabname, date_from, date_to): + #print("date_to", date_to) + date_list = auto_sorting(dir_name) + date_from_list = [date for date in date_list if date <= date_to] + date_from = date_from_list[0] if len(date_from_list) < show_max_dates_num else date_from_list[-show_max_dates_num] + image_list, page_index, image_list, _ =get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to) + return image_list, page_index, image_list, _, gradio.Dropdown.update(choices=date_from_list, value=date_from) -def prev_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, page_index, -1, image_index, tabname) +def first_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): + return get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to) -def next_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, page_index, 1, image_index, tabname) +def end_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): + return get_recent_images(dir_name, -1, 0, image_index, tabname, date_from, date_to) -def page_index_change(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, page_index, 0, image_index, tabname) +def prev_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): + return get_recent_images(dir_name, page_index, -1, image_index, tabname, date_from, date_to) -def show_image_info(num, image_path, filenames): - # print(f"select image {num}") - file = filenames[int(num)] - return file, num, os.path.join(image_path, file) +def next_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): + return get_recent_images(dir_name, page_index, 1, image_index, tabname, date_from, date_to) + + +def page_index_change(dir_name, page_index, image_index, tabname, date_from, date_to): + return get_recent_images(dir_name, page_index, 0, image_index, tabname, date_from, date_to) -def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index): +def show_image_info(tabname_box, num, filenames): + # #print(f"select image {num}") + file = filenames[int(num)] + return file, num, file + +def delete_image(delete_num, tabname, name, page_index, filenames, image_index): if name == "": return filenames, delete_num else: @@ -81,21 +183,19 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima new_file_list = [] for name in filenames: if i >= index and i < index + delete_num: - path = os.path.join(dir_name, name) - if os.path.exists(path): - print(f"Delete file {path}") - os.remove(path) - txt_file = os.path.splitext(path)[0] + ".txt" + if os.path.exists(name): + #print(f"Delete file {name}") + os.remove(name) + txt_file = os.path.splitext(name)[0] + ".txt" if os.path.exists(txt_file): os.remove(txt_file) else: - print(f"Not exists file {path}") + #print(f"Not exists file {name}") else: new_file_list.append(name) i += 1 return new_file_list, 1 - def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): if tabname == "txt2img": dir_name = opts.outdir_txt2img_samples @@ -107,16 +207,32 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): dir_name = d[0] for p in d[1:]: dir_name = os.path.join(dir_name, p) - with gr.Row(): - renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page") - first_page = gr.Button('First Page') - prev_page = gr.Button('Prev Page') - page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next Page') - end_page = gr.Button('End Page') - with gr.Row(elem_id=tabname + "_images_history"): + + f_list = os.listdir(dir_name) + sorted_flag = os.path.exists(os.path.join(dir_name, system_bak_path)) or len(f_list) == 0 + date_list, date_from, date_to = None, None, None + if sorted_flag: + #print(sorted_flag) + date_list = auto_sorting(dir_name) + date_to = date_list[-1] + date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0] + + with gr.Column(visible=sorted_flag) as page_panel: with gr.Row(): + renew_page = gr.Button('Refresh', elem_id=tabname + "_images_history_renew_page", interactive=sorted_flag) + first_page = gr.Button('First Page') + prev_page = gr.Button('Prev Page') + page_index = gr.Number(value=1, label="Page Index") + next_page = gr.Button('Next Page') + end_page = gr.Button('End Page') + + with gr.Row(elem_id=tabname + "_images_history"): with gr.Column(scale=2): + with gr.Row(): + newest = gr.Button('Newest') + date_to = gr.Dropdown(choices=date_list, value=date_to, label="Date to") + date_from = gr.Dropdown(choices=date_list, value=date_from, label="Date from") + history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) with gr.Row(): delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") @@ -128,22 +244,31 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Row(): with gr.Column(): img_file_info = gr.Textbox(label="Generate Info", interactive=False) - img_file_name = gr.Textbox(label="File Name", interactive=False) - with gr.Row(): + img_file_name = gr.Textbox(value="", label="File Name", interactive=False) # hiden items + with gr.Row(visible=False): + img_path = gr.Textbox(dir_name) + tabname_box = gr.Textbox(tabname) + image_index = gr.Textbox(value=-1) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") + filenames = gr.State() + hidden = gr.Image(type="pil") + info1 = gr.Textbox() + info2 = gr.Textbox() + with gr.Column(visible=not sorted_flag) as init_warning: + with gr.Row(): + gr.Textbox("The system needs to archive the files according to the date. This requires changing the directory structure of the files", + label="Waring", + css="") + with gr.Row(): + sorted_button = gr.Button('Confirme') - img_path = gr.Textbox(dir_name.rstrip("/"), visible=False) - tabname_box = gr.Textbox(tabname, visible=False) - image_index = gr.Textbox(value=-1, visible=False) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False) - filenames = gr.State() - hidden = gr.Image(type="pil", visible=False) - info1 = gr.Textbox(visible=False) - info2 = gr.Textbox(visible=False) - + + + # turn pages - gallery_inputs = [img_path, page_index, image_index, tabname_box] - gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name] + gallery_inputs = [img_path, page_index, image_index, tabname_box, date_from, date_to] + gallery_outputs = [history_gallery, page_index, filenames, img_file_name] first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) @@ -154,15 +279,21 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index]) # other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden]) + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, filenames], outputs=[img_file_name, image_index, hidden]) img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) - delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num]) + delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num]) hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) - + date_to.change(date_to_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs + [date_from]) # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') + sorted_button.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from]) + newest.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from]) + + + + def create_history_tabs(gr, opts, run_pnginfo, switch_dict): with gr.Blocks(analytics_enabled=False) as images_history: -- cgit v1.2.3 From 0c5fa9a681672508adadbe1e10fc16d7fe0ed6dd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 16 Oct 2022 08:51:24 +0300 Subject: do not reload embeddings from disk when doing textual inversion --- modules/processing.py | 5 +++-- modules/textual_inversion/textual_inversion.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 941ae089..833fed8a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -53,7 +53,7 @@ def get_correct_sampler(p): return sd_samplers.samplers_for_img2img class StableDiffusionProcessing: - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None, do_not_reload_embeddings=False): self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids @@ -80,6 +80,7 @@ class StableDiffusionProcessing: self.extra_generation_params: dict = extra_generation_params or {} self.overlay_images = overlay_images self.eta = eta + self.do_not_reload_embeddings = do_not_reload_embeddings self.paste_to = None self.color_corrections = None self.denoising_strength: float = 0 @@ -364,7 +365,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: def infotext(iteration=0, position_in_batch=0): return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) - if os.path.exists(cmd_opts.embeddings_dir): + if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() infotexts = [] diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 2ed345b1..7ec75018 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -296,6 +296,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc sd_model=shared.sd_model, do_not_save_grid=True, do_not_save_samples=True, + do_not_reload_embeddings=True, ) if preview_from_txt2img: -- cgit v1.2.3 From 2ce27728f6433911274efa67856315d22df56629 Mon Sep 17 00:00:00 2001 From: winterspringsummer Date: Sun, 16 Oct 2022 13:50:55 +0900 Subject: added extras batch work from directory --- modules/extras.py | 23 ++++++++++++++++++----- modules/ui.py | 12 ++++++++++++ 2 files changed, 30 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index f2f5a7b0..5b52b27d 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -20,26 +20,38 @@ import gradio as gr cached_images = {} -def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): devices.torch_gc() imageArr = [] # Also keep track of original file names imageNameArr = [] - + outputs = [] + if extras_mode == 1: #convert file to pillow image for img in image_folder: image = Image.open(img) imageArr.append(image) imageNameArr.append(os.path.splitext(img.orig_name)[0]) + elif extras_mode == 2: + if input_dir == '': + return outputs, "Please select an input directory.", '' + image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)] + for img in image_list: + image = Image.open(img) + imageArr.append(image) + imageNameArr.append(img) else: imageArr.append(image) imageNameArr.append(None) - outpath = opts.outdir_samples or opts.outdir_extras_samples + if extras_mode == 2 and output_dir != '': + outpath = output_dir + else: + outpath = opts.outdir_samples or opts.outdir_extras_samples - outputs = [] + for image, image_name in zip(imageArr, imageNameArr): if image is None: return outputs, "Please select an input image.", '' @@ -112,7 +124,8 @@ def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, image.info = existing_pnginfo image.info["extras"] = info - outputs.append(image) + if extras_mode != 2 or show_extras_results : + outputs.append(image) devices.torch_gc() diff --git a/modules/ui.py b/modules/ui.py index b867d40f..08fa72c6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1016,6 +1016,15 @@ def create_ui(wrap_gradio_gpu_call): with gr.TabItem('Batch Process'): image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") + with gr.TabItem('Batch from Directory'): + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, + placeholder="A directory on the same machine where the server is running." + ) + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, + placeholder="Leave blank to save images to the default path." + ) + show_extras_results = gr.Checkbox(label='Show result images', value=True) + with gr.Tabs(elem_id="extras_resize_mode"): with gr.TabItem('Scale by'): upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2) @@ -1060,6 +1069,9 @@ def create_ui(wrap_gradio_gpu_call): dummy_component, extras_image, image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, -- cgit v1.2.3 From 179e3ca752d0133470fd3ae44153ee0b71450c9f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 16 Oct 2022 09:51:01 +0300 Subject: honor --hide-ui-dir-config option for #2807 --- modules/extras.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 5b52b27d..0819ed37 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -35,6 +35,8 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ imageArr.append(image) imageNameArr.append(os.path.splitext(img.orig_name)[0]) elif extras_mode == 2: + assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' + if input_dir == '': return outputs, "Please select an input directory.", '' image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)] -- cgit v1.2.3 From 3395ba493f93214cf037d084d45693a37610bd85 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Sun, 16 Oct 2022 09:24:01 +0900 Subject: Allow specifying the region of ngrok. --- modules/ngrok.py | 8 +++++--- modules/shared.py | 1 + modules/ui.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/ngrok.py b/modules/ngrok.py index 7d03a6df..5c5f349a 100644 --- a/modules/ngrok.py +++ b/modules/ngrok.py @@ -1,12 +1,14 @@ from pyngrok import ngrok, conf, exception -def connect(token, port): +def connect(token, port, region): if token == None: token = 'None' - conf.get_default().auth_token = token + config = conf.PyngrokConfig( + auth_token=token, region=region + ) try: - public_url = ngrok.connect(port).public_url + public_url = ngrok.connect(port, pyngrok_config=config).public_url except exception.PyngrokNgrokError: print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') diff --git a/modules/shared.py b/modules/shared.py index fa30bbb0..dcab0af9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -40,6 +40,7 @@ parser.add_argument("--unload-gfpgan", action='store_true', help="does not do an parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) +parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) diff --git a/modules/ui.py b/modules/ui.py index 08fa72c6..5c0eaf73 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -56,7 +56,7 @@ if not cmd_opts.share and not cmd_opts.listen: if cmd_opts.ngrok != None: import modules.ngrok as ngrok print('ngrok authtoken detected, trying to connect...') - ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860) + ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region) def gr_show(visible=True): -- cgit v1.2.3 From 20bf99052a9d50b5f99d199f4c449ef1ddd6e3cb Mon Sep 17 00:00:00 2001 From: CookieHCl Date: Sun, 16 Oct 2022 04:47:03 +0900 Subject: Make style configurable in ui-config.json --- modules/ui.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 5c0eaf73..78096f27 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -508,9 +508,11 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=1, elem_id="style_pos_col"): prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + prompt_style.save_to_config = True with gr.Column(scale=1, elem_id="style_neg_col"): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + prompt_style2.save_to_config = True return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button @@ -1739,6 +1741,11 @@ Requested path was: {f} if type(x) == gr.Number: apply_field(x, 'value') + # Since there are many dropdowns that shouldn't be saved, + # we only mark dropdowns that should be saved. + if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False): + apply_field(x, 'value') + visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") visit(extras_interface, loadsave, "extras") -- cgit v1.2.3 From b65a3101ce82b42b4ccc525044548e66cc44ae4a Mon Sep 17 00:00:00 2001 From: CookieHCl Date: Sun, 16 Oct 2022 04:54:53 +0900 Subject: Use default value when dropdown ui setting is bad Default value is the first value of selectables. Particually, None in styles. --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 78096f27..c8e68bd6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1744,7 +1744,7 @@ Requested path was: {f} # Since there are many dropdowns that shouldn't be saved, # we only mark dropdowns that should be saved. if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False): - apply_field(x, 'value') + apply_field(x, 'value', lambda val: val in x.choices) visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") -- cgit v1.2.3 From 9258a33e3755c76922cd47a03cd59419b6426304 Mon Sep 17 00:00:00 2001 From: CookieHCl Date: Sun, 16 Oct 2022 05:09:11 +0900 Subject: Warn when user uses bad ui setting --- modules/ui.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index c8e68bd6..10bdf121 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1717,7 +1717,9 @@ Requested path was: {f} saved_value = ui_settings.get(key, None) if saved_value is None: ui_settings[key] = getattr(obj, field) - elif condition is None or condition(saved_value): + elif condition and not condition(saved_value): + print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + else: setattr(obj, field, saved_value) if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible: -- cgit v1.2.3 From 36a0ba357ab0742c3c4a28437b68fb29a235afbe Mon Sep 17 00:00:00 2001 From: Junpeng Qiu Date: Sat, 15 Oct 2022 21:42:52 -0700 Subject: Added Refresh Button to embedding and hypernetwork names in Train Tab Problem everytime I modified pt files in embedding_dir or hypernetwork_dir, I need to restart webui to have the new files shown in the dropdown of Train Tab Solution refactored create_refresh_button out of create_setting_component so we can use this method to create button next to gr.Dropdowns of embedding name and hypernetworks Extra Modification hypernetwork pt are now sorted in alphabetic order --- modules/ui.py | 45 ++++++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 19 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 10bdf121..ee3d0248 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -568,6 +568,24 @@ def create_ui(wrap_gradio_gpu_call): import modules.img2img import modules.txt2img + def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn = refresh, + inputs = [], + outputs = [refresh_component] + ) + return refresh_button + with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) @@ -1205,8 +1223,12 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tab(label="Train"): gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") - train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) + with gr.Row(): + train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + with gr.Row(): + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) + create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") batch_size = gr.Number(label='Batch size', value=1, precision=0) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") @@ -1357,26 +1379,11 @@ def create_ui(wrap_gradio_gpu_call): if info.refresh is not None: if is_quicksettings: res = comp(label=info.label, value=fun, **(args or {})) - refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key) + refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: with gr.Row(variant="compact"): res = comp(label=info.label, value=fun, **(args or {})) - refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key) - - def refresh(): - info.refresh() - refreshed_args = info.component_args() if callable(info.component_args) else info.component_args - - for k, v in refreshed_args.items(): - setattr(res, k, v) - - return gr.update(**(refreshed_args or {})) - - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[res], - ) + refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: res = comp(label=info.label, value=fun, **(args or {})) -- cgit v1.2.3 From 523140d7805c644700009b8a2483ff4eb4a22304 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 10:23:30 +0200 Subject: ui fix --- modules/aesthetic_clip.py | 3 +-- modules/sd_hijack.py | 3 +-- modules/shared.py | 2 ++ modules/ui.py | 24 ++++++++++++++---------- 4 files changed, 18 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index 68264284..ccb35c73 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -74,5 +74,4 @@ def generate_imgs_embd(name, folder, batch_size): """ shared.update_aesthetic_embeddings() return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", - value=sorted(shared.aesthetic_embeddings.keys())[0] if len( - shared.aesthetic_embeddings) > 0 else None), res, "" + value="None"), res, "" diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 01fcb78f..2de2eed5 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -392,8 +392,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - if len(text[ - 0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: + if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: if not opts.use_old_emphasis_implementation: remade_batch_tokens = [ [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in diff --git a/modules/shared.py b/modules/shared.py index 3c5ffef1..e2c98b2d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -96,11 +96,13 @@ loaded_hypernetwork = None aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} +aesthetic_embeddings = aesthetic_embeddings | {"None": None} def update_aesthetic_embeddings(): global aesthetic_embeddings aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} + aesthetic_embeddings = aesthetic_embeddings | {"None": None} def reload_hypernetworks(): global hypernetworks diff --git a/modules/ui.py b/modules/ui.py index 13ba3142..4069f0d2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -594,19 +594,23 @@ def create_ui(wrap_gradio_gpu_call): height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) with gr.Group(): - aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") - - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) + with gr.Accordion("Open for Clip Aesthetic!",open=False): + with gr.Row(): + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) + aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) - with gr.Row(): - aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") - aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) - aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) + with gr.Row(): + aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") + aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) + aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), + label="Aesthetic imgs embedding", + value="None") - aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Aesthetic imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) + with gr.Row(): + aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") + aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) + aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) - aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) -- cgit v1.2.3 From e4f8b5f00dd33b7547cc6b76fbed26bb83b37a64 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 10:28:21 +0200 Subject: ui fix --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2de2eed5..5d0590af 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -178,7 +178,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.load_image_embs(image_embs_name) def load_image_embs(self, image_embs_name): - if image_embs_name is None or len(image_embs_name) == 0: + if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": image_embs_name = None if image_embs_name is not None and self.image_embs_name != image_embs_name: self.image_embs_name = image_embs_name -- cgit v1.2.3 From f62905fdf928b54aa76765e5cbde8d538d494e49 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sun, 16 Oct 2022 21:22:38 +0800 Subject: images history speed up --- modules/images_history.py | 250 ++++++++++++++++++++++++---------------------- 1 file changed, 128 insertions(+), 122 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 7fd75005..ae0b4e40 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -3,8 +3,10 @@ import shutil import time import hashlib import gradio -show_max_dates_num = 3 + system_bak_path = "webui_log_and_bak" +loads_files_num = 216 +num_of_imgs_per_page = 36 def is_valid_date(date): try: time.strptime(date, "%Y%m%d") @@ -53,38 +55,7 @@ def traverse_all_files(curr_path, image_list, all_type=False): image_list = traverse_all_files(file, image_list) return image_list -def get_recent_images(dir_name, page_index, step, image_index, tabname, date_from, date_to): - #print(f"turn_page {page_index}",date_from) - if date_from is None or date_from == "": - return None, 1, None, "" - image_list = [] - date_list = auto_sorting(dir_name) - page_index = int(page_index) - today = time.strftime("%Y%m%d",time.localtime(time.time())) - for date in date_list: - if date >= date_from and date <= date_to: - path = os.path.join(dir_name, date) - if date == today and not os.path.exists(path): - continue - image_list = traverse_all_files(path, image_list) - - image_list = sorted(image_list, key=lambda file: -os.path.getctime(file)) - num = 48 if tabname != "extras" else 12 - max_page_index = len(image_list) // num + 1 - page_index = max_page_index if page_index == -1 else page_index + step - page_index = 1 if page_index < 1 else page_index - page_index = max_page_index if page_index > max_page_index else page_index - idx_frm = (page_index - 1) * num - image_list = image_list[idx_frm:idx_frm + num] - image_index = int(image_index) - if image_index < 0 or image_index > len(image_list) - 1: - current_file = None - else: - current_file = image_list[image_index] - return image_list, page_index, image_list, "" - -def auto_sorting(dir_name): - #print(f"auto sorting") +def auto_sorting(dir_name): bak_path = os.path.join(dir_name, system_bak_path) if not os.path.exists(bak_path): os.mkdir(bak_path) @@ -126,102 +97,131 @@ def auto_sorting(dir_name): today = time.strftime("%Y%m%d",time.localtime(time.time())) if today not in date_list: date_list.append(today) - return sorted(date_list) + return sorted(date_list, reverse=True) -def archive_images(dir_name): +def archive_images(dir_name, date_to): date_list = auto_sorting(dir_name) - date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0] + today = time.strftime("%Y%m%d",time.localtime(time.time())) + date_to = today if date_to is None or date_to == "" else date_to + filenames = [] + for date in date_list: + if date <= date_to: + path = os.path.join(dir_name, date) + if date == today and not os.path.exists(path): + continue + filenames = traverse_all_files(path, filenames) + if len(filenames) > loads_files_num: + break + filenames = sorted(filenames, key=lambda file: -os.path.getctime(file)) + _, image_list, _, visible_num = get_recent_images(1, 0, filenames) return ( gradio.update(visible=False), gradio.update(visible=True), - gradio.Dropdown.update(choices=date_list, value=date_list[-1]), - gradio.Dropdown.update(choices=date_list, value=date_from) + gradio.Dropdown.update(choices=date_list, value=date_to), + date, + filenames, + 1, + image_list, + "", + visible_num ) +def system_init(dir_name): + ret = [x for x in archive_images(dir_name, None)] + ret += [gradio.update(visible=False)] + return ret + +def newest_click(dir_name, date_to): + if date_to == "start": + return True, False, "start", None, None, 1, None, "" + else: + return archive_images(dir_name, time.strftime("%Y%m%d",time.localtime(time.time()))) -def date_to_change(dir_name, page_index, image_index, tabname, date_from, date_to): - #print("date_to", date_to) - date_list = auto_sorting(dir_name) - date_from_list = [date for date in date_list if date <= date_to] - date_from = date_from_list[0] if len(date_from_list) < show_max_dates_num else date_from_list[-show_max_dates_num] - image_list, page_index, image_list, _ =get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to) - return image_list, page_index, image_list, _, gradio.Dropdown.update(choices=date_from_list, value=date_from) - -def first_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): - return get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to) - - -def end_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): - return get_recent_images(dir_name, -1, 0, image_index, tabname, date_from, date_to) - - -def prev_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): - return get_recent_images(dir_name, page_index, -1, image_index, tabname, date_from, date_to) - - -def next_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): - return get_recent_images(dir_name, page_index, 1, image_index, tabname, date_from, date_to) - - -def page_index_change(dir_name, page_index, image_index, tabname, date_from, date_to): - return get_recent_images(dir_name, page_index, 0, image_index, tabname, date_from, date_to) - - -def show_image_info(tabname_box, num, filenames): - # #print(f"select image {num}") - file = filenames[int(num)] - return file, num, file - -def delete_image(delete_num, tabname, name, page_index, filenames, image_index): +def delete_image(delete_num, name, filenames, image_index, visible_num): if name == "": return filenames, delete_num else: delete_num = int(delete_num) + visible_num = int(visible_num) + image_index = int(image_index) index = list(filenames).index(name) i = 0 new_file_list = [] for name in filenames: if i >= index and i < index + delete_num: if os.path.exists(name): - #print(f"Delete file {name}") + if visible_num == image_index: + new_file_list.append(name) + continue + print(f"Delete file {name}") os.remove(name) + visible_num -= 1 txt_file = os.path.splitext(name)[0] + ".txt" if os.path.exists(txt_file): os.remove(txt_file) else: - #print(f"Not exists file {name}") + print(f"Not exists file {name}") else: new_file_list.append(name) i += 1 - return new_file_list, 1 + return new_file_list, 1, visible_num + +def get_recent_images(page_index, step, filenames): + page_index = int(page_index) + max_page_index = len(filenames) // num_of_imgs_per_page + 1 + page_index = max_page_index if page_index == -1 else page_index + step + page_index = 1 if page_index < 1 else page_index + page_index = max_page_index if page_index > max_page_index else page_index + idx_frm = (page_index - 1) * num_of_imgs_per_page + image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page] + length = len(filenames) + visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page + visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num + return page_index, image_list, "", visible_num + +def first_page_click(page_index, filenames): + return get_recent_images(1, 0, filenames) + +def end_page_click(page_index, filenames): + return get_recent_images(-1, 0, filenames) + +def prev_page_click(page_index, filenames): + return get_recent_images(page_index, -1, filenames) + +def next_page_click(page_index, filenames): + return get_recent_images(page_index, 1, filenames) + +def page_index_change(page_index, filenames): + return get_recent_images(page_index, 0, filenames) + +def show_image_info(tabname_box, num, page_index, filenames): + file = filenames[int(num) + int((page_index - 1) * num_of_imgs_per_page)] + return file, num, file def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): - if opts.outdir_samples != "": - dir_name = opts.outdir_samples - elif tabname == "txt2img": + if tabname == "txt2img": dir_name = opts.outdir_txt2img_samples elif tabname == "img2img": dir_name = opts.outdir_img2img_samples elif tabname == "extras": dir_name = opts.outdir_extras_samples + elif tabname == "saved": + dir_name = opts.outdir_save + if not os.path.exists(dir_name): + os.makedirs(dir_name) d = dir_name.split("/") - dir_name = "/" if dir_name.startswith("/") else d[0] + dir_name = d[0] for p in d[1:]: dir_name = os.path.join(dir_name, p) f_list = os.listdir(dir_name) sorted_flag = os.path.exists(os.path.join(dir_name, system_bak_path)) or len(f_list) == 0 date_list, date_from, date_to = None, None, None - if sorted_flag: - #print(sorted_flag) - date_list = auto_sorting(dir_name) - date_to = date_list[-1] - date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0] with gr.Column(visible=sorted_flag) as page_panel: with gr.Row(): - renew_page = gr.Button('Refresh', elem_id=tabname + "_images_history_renew_page", interactive=sorted_flag) + #renew_page = gr.Button('Refresh') first_page = gr.Button('First Page') prev_page = gr.Button('Prev Page') page_index = gr.Number(value=1, label="Page Index") @@ -231,9 +231,9 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Row(elem_id=tabname + "_images_history"): with gr.Column(scale=2): with gr.Row(): - newest = gr.Button('Newest') - date_to = gr.Dropdown(choices=date_list, value=date_to, label="Date to") - date_from = gr.Dropdown(choices=date_list, value=date_from, label="Date from") + newest = gr.Button('Refresh', elem_id=tabname + "_images_history_start") + date_from = gr.Textbox(label="Date from", interactive=False) + date_to = gr.Dropdown(value="start" if not sorted_flag else None, label="Date to") history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) with gr.Row(): @@ -247,66 +247,72 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Column(): img_file_info = gr.Textbox(label="Generate Info", interactive=False) img_file_name = gr.Textbox(value="", label="File Name", interactive=False) + # hiden items - with gr.Row(visible=False): + with gr.Row(visible=False): + visible_img_num = gr.Number() img_path = gr.Textbox(dir_name) tabname_box = gr.Textbox(tabname) image_index = gr.Textbox(value=-1) set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") filenames = gr.State() + all_images_list = gr.State() hidden = gr.Image(type="pil") info1 = gr.Textbox() info2 = gr.Textbox() + with gr.Column(visible=not sorted_flag) as init_warning: with gr.Row(): - gr.Textbox("The system needs to archive the files according to the date. This requires changing the directory structure of the files", - label="Waring", - css="") + warning = gr.Textbox( + label="Waring", + value=f"The system needs to archive the files according to the date. This requires changing the directory structure of the files.If you have doubts about this operation, you can first back up the files in the '{dir_name}' directory" + ) + warning.style(height=100, width=50) with gr.Row(): sorted_button = gr.Button('Confirme') - - + change_date_output = [init_warning, page_panel, date_to, date_from, filenames, page_index, history_gallery, img_file_name, visible_img_num] + sorted_button.click(system_init, inputs=[img_path], outputs=change_date_output + [sorted_button]) + newest.click(newest_click, inputs=[img_path, date_to], outputs=change_date_output) + date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output) + date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + newest.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + + delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num]) + delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None) + # turn pages - gallery_inputs = [img_path, page_index, image_index, tabname_box, date_from, date_to] - gallery_outputs = [history_gallery, page_index, filenames, img_file_name] + gallery_inputs = [page_index, filenames] + gallery_outputs = [page_index, history_gallery, img_file_name, visible_img_num] + + first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) - first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index]) + first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") # other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, filenames], outputs=[img_file_name, image_index, hidden]) - img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) - delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num]) + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden]) + img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) - date_to.change(date_to_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs + [date_from]) - # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) + switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') - sorted_button.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from]) - newest.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from]) - - - def create_history_tabs(gr, opts, run_pnginfo, switch_dict): with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: - with gr.Tab("txt2img history"): - with gr.Blocks(analytics_enabled=False) as images_history_txt2img: - show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict) - with gr.Tab("img2img history"): - with gr.Blocks(analytics_enabled=False) as images_history_img2img: - show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict) - with gr.Tab("extras history"): - with gr.Blocks(analytics_enabled=False) as images_history_img2img: - show_images_history(gr, opts, "extras", run_pnginfo, switch_dict) + for tab in ["saved", "txt2img", "img2img", "extras"]: + with gr.Tab(tab): + with gr.Blocks(analytics_enabled=False) as images_history_img2img: + show_images_history(gr, opts, tab, run_pnginfo, switch_dict) return images_history -- cgit v1.2.3 From 91235d8008372862b1f232f7bf99da310a5955e4 Mon Sep 17 00:00:00 2001 From: CookieHCl Date: Sun, 16 Oct 2022 20:50:24 +0900 Subject: Fix FileNotFoundError in history tab Now only traverse images when directory exists --- modules/images_history.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 9260df8a..e6284142 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -1,6 +1,6 @@ import os import shutil - +import sys def traverse_all_files(output_dir, image_list, curr_dir=None): curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir) @@ -24,10 +24,14 @@ def traverse_all_files(output_dir, image_list, curr_dir=None): def get_recent_images(dir_name, page_index, step, image_index, tabname): page_index = int(page_index) - f_list = os.listdir(dir_name) image_list = [] - image_list = traverse_all_files(dir_name, image_list) - image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) + if not os.path.exists(dir_name): + pass + elif os.path.isdir(dir_name): + image_list = traverse_all_files(dir_name, image_list) + image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) + else: + print(f"ERROR: {dir_name} is not a directory. Check the path in the settings.", file=sys.stderr) num = 48 if tabname != "extras" else 12 max_page_index = len(image_list) // num + 1 page_index = max_page_index if page_index == -1 else page_index + step -- cgit v1.2.3 From c9836279f58461e04c1dda0a86e718f8bd3f41e4 Mon Sep 17 00:00:00 2001 From: CookieHCl Date: Sun, 16 Oct 2022 21:59:05 +0900 Subject: Only make output dir when creating output --- modules/processing.py | 6 ------ modules/ui.py | 5 ++++- 2 files changed, 4 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 833fed8a..deb6125e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -334,12 +334,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed: seed = get_fixed_seed(p.seed) subseed = get_fixed_seed(p.subseed) - if p.outpath_samples is not None: - os.makedirs(p.outpath_samples, exist_ok=True) - - if p.outpath_grids is not None: - os.makedirs(p.outpath_grids, exist_ok=True) - modules.sd_hijack.model_hijack.apply_circular(p.tiling) modules.sd_hijack.model_hijack.clear_comments() diff --git a/modules/ui.py b/modules/ui.py index ee3d0248..fa73627a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1394,7 +1394,10 @@ def create_ui(wrap_gradio_gpu_call): component_dict = {} def open_folder(f): - if not os.path.isdir(f): + if not os.path.exists(f): + print(f"{f} doesn't exist. After you create an image, the folder will be created.") + return + elif not os.path.isdir(f): print(f""" WARNING An open_folder request was made with an argument that is not a folder. -- cgit v1.2.3 From adc0ea74e1ee9791f15c3a74bc6c5ad789e10d17 Mon Sep 17 00:00:00 2001 From: CookieHCl Date: Sun, 16 Oct 2022 22:03:18 +0900 Subject: Better readablity of logs --- modules/images_history.py | 2 +- modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index e6284142..e06e07bf 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -31,7 +31,7 @@ def get_recent_images(dir_name, page_index, step, image_index, tabname): image_list = traverse_all_files(dir_name, image_list) image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) else: - print(f"ERROR: {dir_name} is not a directory. Check the path in the settings.", file=sys.stderr) + print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr) num = 48 if tabname != "extras" else 12 max_page_index = len(image_list) // num + 1 page_index = max_page_index if page_index == -1 else page_index + step diff --git a/modules/ui.py b/modules/ui.py index fa73627a..7b0d5a92 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1395,7 +1395,7 @@ def create_ui(wrap_gradio_gpu_call): def open_folder(f): if not os.path.exists(f): - print(f"{f} doesn't exist. After you create an image, the folder will be created.") + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') return elif not os.path.isdir(f): print(f""" -- cgit v1.2.3 From fc220a51cf5bb5bfca83322c16e907a18ec59f6b Mon Sep 17 00:00:00 2001 From: DancingSnow <1121149616@qq.com> Date: Sun, 16 Oct 2022 10:49:21 +0800 Subject: fix dir_path in some path like `D:/Pic/outputs` --- modules/images_history.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index e06e07bf..46b23e56 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -109,10 +109,8 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): dir_name = opts.outdir_img2img_samples elif tabname == "extras": dir_name = opts.outdir_extras_samples - d = dir_name.split("/") - dir_name = "/" if dir_name.startswith("/") else d[0] - for p in d[1:]: - dir_name = os.path.join(dir_name, p) + else: + return with gr.Row(): renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page") first_page = gr.Button('First Page') -- cgit v1.2.3 From a4de699e3c235d83b5a957d08779cb41cb0781bc Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sun, 16 Oct 2022 22:37:12 +0800 Subject: Images history speed up --- modules/images_history.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index ae0b4e40..94bd16a8 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -153,6 +153,7 @@ def delete_image(delete_num, name, filenames, image_index, visible_num): if os.path.exists(name): if visible_num == image_index: new_file_list.append(name) + i += 1 continue print(f"Delete file {name}") os.remove(name) @@ -221,7 +222,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Column(visible=sorted_flag) as page_panel: with gr.Row(): - #renew_page = gr.Button('Refresh') + renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") first_page = gr.Button('First Page') prev_page = gr.Button('Prev Page') page_index = gr.Number(value=1, label="Page Index") @@ -231,7 +232,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Row(elem_id=tabname + "_images_history"): with gr.Column(scale=2): with gr.Row(): - newest = gr.Button('Refresh', elem_id=tabname + "_images_history_start") + newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") date_from = gr.Textbox(label="Date from", interactive=False) date_to = gr.Dropdown(value="start" if not sorted_flag else None, label="Date to") @@ -291,12 +292,14 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs) page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) + renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") # other funcitons set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden]) -- cgit v1.2.3 From 9324cdaa3199d65c182858785dd1eca42b192b8e Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 17:53:56 +0200 Subject: ui fix, re organization of the code --- modules/aesthetic_clip.py | 154 +++++++++++++++++++++++++++++++++-- modules/img2img.py | 14 +++- modules/processing.py | 29 ++----- modules/sd_hijack.py | 102 ++--------------------- modules/sd_models.py | 5 +- modules/shared.py | 14 +++- modules/textual_inversion/dataset.py | 2 +- modules/txt2img.py | 18 ++-- modules/ui.py | 52 +++++++----- 9 files changed, 233 insertions(+), 157 deletions(-) (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index ccb35c73..34efa931 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -1,3 +1,4 @@ +import copy import itertools import os from pathlib import Path @@ -7,11 +8,12 @@ import gc import gradio as gr import torch from PIL import Image -from modules import shared -from modules.shared import device -from transformers import CLIPModel, CLIPProcessor +from torch import optim -from tqdm.auto import tqdm +from modules import shared +from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer +from tqdm.auto import tqdm, trange +from modules.shared import opts, device def get_all_images_in_folder(folder): @@ -37,12 +39,39 @@ def iter_to_batched(iterable, n=1): yield chunk +def create_ui(): + with gr.Group(): + with gr.Accordion("Open for Clip Aesthetic!", open=False): + with gr.Row(): + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", + value=0.9) + aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) + + with gr.Row(): + aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', + placeholder="Aesthetic learning rate", value="0.0001") + aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) + aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()), + label="Aesthetic imgs embedding", + value="None") + + with gr.Row(): + aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', + placeholder="This text is used to rotate the feature space of the imgs embs", + value="") + aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01, + value=0.1) + aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) + + return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative + + def generate_imgs_embd(name, folder, batch_size): # clipModel = CLIPModel.from_pretrained( # shared.sd_model.cond_stage_model.clipModel.name_or_path # ) - model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device) - processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path) + model = shared.clip_model.to(device) + processor = CLIPProcessor.from_pretrained(model.name_or_path) with torch.no_grad(): embs = [] @@ -63,7 +92,6 @@ def generate_imgs_embd(name, folder, batch_size): torch.save(embs, path) model = model.cpu() - del model del processor del embs gc.collect() @@ -74,4 +102,114 @@ def generate_imgs_embd(name, folder, batch_size): """ shared.update_aesthetic_embeddings() return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", - value="None"), res, "" + value="None"), \ + gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), + label="Imgs embedding", + value="None"), res, "" + + +def slerp(low, high, val): + low_norm = low / torch.norm(low, dim=1, keepdim=True) + high_norm = high / torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm * high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high + return res + + +class AestheticCLIP: + def __init__(self): + self.skip = False + self.aesthetic_steps = 0 + self.aesthetic_weight = 0 + self.aesthetic_lr = 0 + self.slerp = False + self.aesthetic_text_negative = "" + self.aesthetic_slerp_angle = 0 + self.aesthetic_imgs_text = "" + + self.image_embs_name = None + self.image_embs = None + self.load_image_embs(None) + + def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, + aesthetic_slerp=True, aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False): + self.aesthetic_imgs_text = aesthetic_imgs_text + self.aesthetic_slerp_angle = aesthetic_slerp_angle + self.aesthetic_text_negative = aesthetic_text_negative + self.slerp = aesthetic_slerp + self.aesthetic_lr = aesthetic_lr + self.aesthetic_weight = aesthetic_weight + self.aesthetic_steps = aesthetic_steps + self.load_image_embs(image_embs_name) + + def set_skip(self, skip): + self.skip = skip + + def load_image_embs(self, image_embs_name): + if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": + image_embs_name = None + self.image_embs_name = None + if image_embs_name is not None and self.image_embs_name != image_embs_name: + self.image_embs_name = image_embs_name + self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) + self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) + self.image_embs.requires_grad_(False) + + def __call__(self, z, remade_batch_tokens): + if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None: + tokenizer = shared.sd_model.cond_stage_model.tokenizer + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [ + [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in + remade_batch_tokens] + + tokens = torch.asarray(remade_batch_tokens).to(device) + + model = copy.deepcopy(shared.clip_model).to(device) + model.requires_grad_(True) + if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: + text_embs_2 = model.get_text_features( + **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) + if self.aesthetic_text_negative: + text_embs_2 = self.image_embs - text_embs_2 + text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) + img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) + else: + img_embs = self.image_embs + + with torch.enable_grad(): + + # We optimize the model to maximize the similarity + optimizer = optim.Adam( + model.text_model.parameters(), lr=self.aesthetic_lr + ) + + for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"): + text_embs = model.get_text_features(input_ids=tokens) + text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) + sim = text_embs @ img_embs.T + loss = -sim + optimizer.zero_grad() + loss.mean().backward() + optimizer.step() + + zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + if opts.CLIP_stop_at_last_layers > 1: + zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] + zn = model.text_model.final_layer_norm(zn) + else: + zn = zn.last_hidden_state + model.cpu() + del model + gc.collect() + torch.cuda.empty_cache() + zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1) + if self.slerp: + z = slerp(z, zn, self.aesthetic_weight) + else: + z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight + + return z diff --git a/modules/img2img.py b/modules/img2img.py index 24126774..4ed80c4b 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -56,7 +56,14 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, + aesthetic_lr=0, + aesthetic_weight=0, aesthetic_steps=0, + aesthetic_imgs=None, + aesthetic_slerp=False, + aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False, *args): is_inpaint = mode == 1 is_batch = mode == 2 @@ -109,6 +116,11 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpainting_mask_invert=inpainting_mask_invert, ) + shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), + aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative) + if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index 1db26c3e..685f9fcd 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -146,7 +146,8 @@ class Processed: self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1 - self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 + self.subseed = int( + self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.all_prompts = all_prompts or [self.prompt] self.all_seeds = all_seeds or [self.seed] @@ -332,16 +333,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() -def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, - aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", - aesthetic_slerp_angle=0.15, - aesthetic_text_negative=False) -> Processed: +def process_images(p: StableDiffusionProcessing) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" - aesthetic_lr = float(aesthetic_lr) - aesthetic_weight = float(aesthetic_weight) - aesthetic_steps = int(aesthetic_steps) - if type(p.prompt) == list: assert (len(p.prompt) > 0) else: @@ -417,16 +411,10 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh # uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) # c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): - if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): - shared.sd_model.cond_stage_model.set_aesthetic_params() + shared.aesthetic_clip.set_skip(True) uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) - if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): - shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight, - aesthetic_steps, aesthetic_imgs, - aesthetic_slerp, aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative) + shared.aesthetic_clip.set_skip(False) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: @@ -582,7 +570,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) @@ -600,10 +587,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) - samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] + samples = samples[:, :, self.truncate_y // 2:samples.shape[2] - self.truncate_y // 2, + self.truncate_x // 2:samples.shape[3] - self.truncate_x // 2] if opts.use_scale_latent_for_hires_fix: - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), + mode="bilinear") else: decoded_samples = decode_first_stage(self.sd_model, samples) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5d0590af..227e7670 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -29,8 +29,8 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and ( + 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward @@ -118,33 +118,14 @@ class StableDiffusionModelHijack: return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) -def slerp(low, high, val): - low_norm = low / torch.norm(low, dim=1, keepdim=True) - high_norm = high / torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm * high_norm).sum(1)) - so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high - return res - - class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped - self.clipModel = CLIPModel.from_pretrained( - self.wrapped.transformer.name_or_path - ) - del self.clipModel.vision_model - self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path) - self.hijack: StableDiffusionModelHijack = hijack - self.tokenizer = wrapped.tokenizer - # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval() - self.image_embs_name = None - self.image_embs = None - self.load_image_embs(None) self.token_mults = {} - + self.hijack: StableDiffusionModelHijack = hijack + self.tokenizer = wrapped.tokenizer self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if @@ -164,28 +145,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, - aesthetic_slerp=True, aesthetic_imgs_text="", - aesthetic_slerp_angle=0.15, - aesthetic_text_negative=False): - self.aesthetic_imgs_text = aesthetic_imgs_text - self.aesthetic_slerp_angle = aesthetic_slerp_angle - self.aesthetic_text_negative = aesthetic_text_negative - self.slerp = aesthetic_slerp - self.aesthetic_lr = aesthetic_lr - self.aesthetic_weight = aesthetic_weight - self.aesthetic_steps = aesthetic_steps - self.load_image_embs(image_embs_name) - - def load_image_embs(self, image_embs_name): - if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": - image_embs_name = None - if image_embs_name is not None and self.image_embs_name != image_embs_name: - self.image_embs_name = image_embs_name - self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) - self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) - self.image_embs.requires_grad_(False) - def tokenize_line(self, line, used_custom_terms, hijack_comments): id_end = self.wrapped.tokenizer.eos_token_id @@ -391,58 +350,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - - if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [ - [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in - remade_batch_tokens] - - tokens = torch.asarray(remade_batch_tokens).to(device) - - model = copy.deepcopy(self.clipModel).to(device) - model.requires_grad_(True) - if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: - text_embs_2 = model.get_text_features( - **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) - if self.aesthetic_text_negative: - text_embs_2 = self.image_embs - text_embs_2 - text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) - img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) - else: - img_embs = self.image_embs - - with torch.enable_grad(): - - # We optimize the model to maximize the similarity - optimizer = optim.Adam( - model.text_model.parameters(), lr=self.aesthetic_lr - ) - - for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"): - text_embs = model.get_text_features(input_ids=tokens) - text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) - sim = text_embs @ img_embs.T - loss = -sim - optimizer.zero_grad() - loss.mean().backward() - optimizer.step() - - zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) - if opts.CLIP_stop_at_last_layers > 1: - zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] - zn = model.text_model.final_layer_norm(zn) - else: - zn = zn.last_hidden_state - model.cpu() - del model - - zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1) - if self.slerp: - z = slerp(z, zn, self.aesthetic_weight) - else: - z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight - + z = shared.aesthetic_clip(z, remade_batch_tokens) remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 diff --git a/modules/sd_models.py b/modules/sd_models.py index 3aa21ec1..8e4ee435 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -20,7 +20,7 @@ checkpoints_loaded = collections.OrderedDict() try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. - from transformers import logging + from transformers import logging, CLIPModel logging.set_verbosity_error() except Exception: @@ -196,6 +196,9 @@ def load_model(): sd_hijack.model_hijack.hijack(sd_model) + if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path: + shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path) + sd_model.eval() print(f"Model loaded.") diff --git a/modules/shared.py b/modules/shared.py index e2c98b2d..e19ca779 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -3,6 +3,7 @@ import datetime import json import os import sys +from collections import OrderedDict import gradio as gr import tqdm @@ -94,15 +95,15 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None -aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in - os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} -aesthetic_embeddings = aesthetic_embeddings | {"None": None} +aesthetic_embeddings = {} def update_aesthetic_embeddings(): global aesthetic_embeddings aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} - aesthetic_embeddings = aesthetic_embeddings | {"None": None} + aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings) + +update_aesthetic_embeddings() def reload_hypernetworks(): global hypernetworks @@ -381,6 +382,11 @@ sd_upscalers = [] sd_model = None +clip_model = None + +from modules.aesthetic_clip import AestheticCLIP +aesthetic_clip = AestheticCLIP() + progress_print_out = sys.stdout diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 68ceffe3..23bb4b6a 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -49,7 +49,7 @@ class PersonalizedBase(Dataset): print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC) + image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) except Exception: continue diff --git a/modules/txt2img.py b/modules/txt2img.py index 8f394d05..6cbc50fc 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,12 +1,17 @@ import modules.scripts -from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \ + StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, cmd_opts import modules.shared as shared import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int,aesthetic_lr=0, +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, + restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, + subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, + height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, + firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, @@ -41,15 +46,17 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: firstphase_height=firstphase_height if enable_hr else None, ) + shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), + aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, + aesthetic_text_negative) + if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is None: - processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative) + processed = process_images(p) shared.total_tqdm.clear() @@ -61,4 +68,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: processed.images = [] return processed.images, generation_info_js, plaintext_to_html(processed.info) - diff --git a/modules/ui.py b/modules/ui.py index 4069f0d2..0e5d73f0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -43,7 +43,7 @@ from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.aesthetic_clip +import modules.aesthetic_clip as aesthetic_clip import modules.images_history as img_his @@ -593,23 +593,25 @@ def create_ui(wrap_gradio_gpu_call): width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - with gr.Group(): - with gr.Accordion("Open for Clip Aesthetic!",open=False): - with gr.Row(): - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) - - with gr.Row(): - aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") - aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) - aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), - label="Aesthetic imgs embedding", - value="None") - - with gr.Row(): - aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") - aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) - aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) + # with gr.Group(): + # with gr.Accordion("Open for Clip Aesthetic!",open=False): + # with gr.Row(): + # aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) + # aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) + # + # with gr.Row(): + # aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") + # aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) + # aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), + # label="Aesthetic imgs embedding", + # value="None") + # + # with gr.Row(): + # aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") + # aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) + # aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) + + aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui() with gr.Row(): @@ -840,6 +842,9 @@ def create_ui(wrap_gradio_gpu_call): width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui() + + with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) tiling = gr.Checkbox(label='Tiling', value=False) @@ -944,6 +949,14 @@ def create_ui(wrap_gradio_gpu_call): inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, + aesthetic_lr_im, + aesthetic_weight_im, + aesthetic_steps_im, + aesthetic_imgs_im, + aesthetic_slerp_im, + aesthetic_imgs_text_im, + aesthetic_slerp_angle_im, + aesthetic_text_negative_im, ] + custom_inputs, outputs=[ img2img_gallery, @@ -1283,7 +1296,7 @@ def create_ui(wrap_gradio_gpu_call): ) create_embedding_ae.click( - fn=modules.aesthetic_clip.generate_imgs_embd, + fn=aesthetic_clip.generate_imgs_embd, inputs=[ new_embedding_name_ae, process_src_ae, @@ -1291,6 +1304,7 @@ def create_ui(wrap_gradio_gpu_call): ], outputs=[ aesthetic_imgs, + aesthetic_imgs_im, ti_output, ti_outcome, ] -- cgit v1.2.3 From c8045c5ad4f99deb3a19add06e0457de1df62b05 Mon Sep 17 00:00:00 2001 From: SGKoishi Date: Sun, 16 Oct 2022 10:08:23 -0700 Subject: The hide_ui_dir_config flag also restrict write attempt to path settings --- modules/shared.py | 10 ++++++++++ modules/ui.py | 8 +++++++- 2 files changed, 17 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index dcab0af9..c2775603 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -77,6 +77,16 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl cmd_opts = parser.parse_args() +restricted_opts = [ + "samples_filename_pattern", + "outdir_samples", + "outdir_txt2img_samples", + "outdir_img2img_samples", + "outdir_extras_samples", + "outdir_grids", + "outdir_txt2img_grids", + "outdir_save", +] devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer']) diff --git a/modules/ui.py b/modules/ui.py index 7b0d5a92..43dc88fc 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -25,7 +25,7 @@ import gradio.routes from modules import sd_hijack, sd_models from modules.paths import script_path -from modules.shared import opts, cmd_opts +from modules.shared import opts, cmd_opts, restricted_opts if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags import modules.shared as shared @@ -1430,6 +1430,9 @@ Requested path was: {f} if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: continue + if cmd_opts.hide_ui_dir_config and key in restricted_opts: + continue + oldval = opts.data.get(key, None) opts.data[key] = value @@ -1447,6 +1450,9 @@ Requested path was: {f} if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() + if cmd_opts.hide_ui_dir_config and key in restricted_opts: + return gr.update(value=oldval), opts.dumpjson() + oldval = opts.data.get(key, None) opts.data[key] = value -- cgit v1.2.3 From 0fd130767102ebcf90e97c6c191ecf199a2d4091 Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Sun, 16 Oct 2022 18:44:39 -0400 Subject: improve performance of 3-way merge on machines with not enough ram, by only accessing two of the models at a time --- modules/extras.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 0819ed37..340a45fd 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -175,11 +175,14 @@ def run_pnginfo(image): def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name): - def weighted_sum(theta0, theta1, theta2, alpha): + def weighted_sum(theta0, theta1, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) - def add_difference(theta0, theta1, theta2, alpha): - return theta0 + (theta1 - theta2) * alpha + def get_difference(theta1, theta2): + return theta1 - theta2 + + def add_difference(theta0, theta1_2_diff, alpha): + return theta0 + (alpha * theta1_2_diff) primary_model_info = sd_models.checkpoints_list[primary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name] @@ -201,20 +204,24 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam theta_2 = None theta_funcs = { - "Weighted sum": weighted_sum, - "Add difference": add_difference, + "Weighted sum": (None, weighted_sum), + "Add difference": (get_difference, add_difference), } - theta_func = theta_funcs[interp_method] + theta_func1, theta_func2 = theta_funcs[interp_method] print(f"Merging...") + if theta_func1: + for key in tqdm.tqdm(theta_1.keys()): + if 'model' in key: + t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) + theta_1[key] = theta_func1(theta_1[key], t2) + del theta_2, teritary_model + for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: - t2 = (theta_2 or {}).get(key) - if t2 is None: - t2 = torch.zeros_like(theta_0[key]) - theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier) + theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier) if save_as_half: theta_0[key] = theta_0[key].half() -- cgit v1.2.3 From 6f7b7a3dcdc471ebe63baa8a7731952557859c5b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 17 Oct 2022 07:56:08 +0300 Subject: only read files with .py extension from the scripts dir --- modules/scripts.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 45230f9a..ac66d448 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -58,6 +58,9 @@ def load_scripts(basedir): for filename in sorted(os.listdir(basedir)): path = os.path.join(basedir, filename) + if os.path.splitext(path)[1].lower() != '.py': + continue + if not os.path.isfile(path): continue -- cgit v1.2.3 From 58f3ef77336663bce2321f5b692cf2aeacd3ac1c Mon Sep 17 00:00:00 2001 From: DenkingOfficial Date: Mon, 17 Oct 2022 03:10:59 +0500 Subject: Fix CLIP Interrogator and disable ranks for it --- modules/interrogate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/interrogate.py b/modules/interrogate.py index 9263d65a..d85d7dcc 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -157,9 +157,9 @@ class InterrogateModels: matches = self.rank(image_features, items, top_count=topn) for match, score in matches: if include_ranks: - res += ", " + match - else: res += f", ({match}:{score})" + else: + res += ", " + match except Exception: print(f"Error interrogating", file=sys.stderr) -- cgit v1.2.3 From 5c94aaf290f8ad7bf4499a91c268ad0791b0432f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 17 Oct 2022 08:28:18 +0300 Subject: fix bug for latest model merge RAM improvement --- modules/extras.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 340a45fd..8dbab240 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -201,6 +201,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam teritary_model = torch.load(teritary_model_info.filename, map_location='cpu') theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model) else: + teritary_model = None theta_2 = None theta_funcs = { -- cgit v1.2.3 From b99d3cf6dd9bc817e51d0d0a6e8eb12c7c0ac6af Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 17 Oct 2022 08:41:02 +0300 Subject: make CLIP interrogate ranks output sane values --- modules/interrogate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/interrogate.py b/modules/interrogate.py index d85d7dcc..64b91eb4 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -123,7 +123,7 @@ class InterrogateModels: return caption[0] - def interrogate(self, pil_image, include_ranks=False): + def interrogate(self, pil_image): res = None try: @@ -156,8 +156,8 @@ class InterrogateModels: for name, topn, items in self.categories: matches = self.rank(image_features, items, top_count=topn) for match, score in matches: - if include_ranks: - res += f", ({match}:{score})" + if shared.opts.interrogate_return_ranks: + res += f", ({match}:{score/100:.3f})" else: res += ", " + match -- cgit v1.2.3 From 62edfae257e8982cd620d03862c7bdd44159d18f Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 16 Oct 2022 20:28:15 +0100 Subject: print list of embeddings on reload --- modules/textual_inversion/textual_inversion.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 7ec75018..3be69562 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -137,6 +137,7 @@ class EmbeddingDatabase: continue print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") + print("Embeddings:", ', '.join(self.word_embeddings.keys())) def find_embedding_at_position(self, tokens, offset): token = tokens[offset] -- cgit v1.2.3 From 9d702b16f01795c3af900e0ebd70faf4b25200f6 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 17 Oct 2022 16:11:03 +0800 Subject: fix two little bug --- modules/images_history.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 23045df1..1ae168ca 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -133,7 +133,7 @@ def archive_images(dir_name, date_to): date = sort_array[loads_num][2] filenames = [x[1] for x in sort_array] else: - date = sort_array[loads_num][2] + date = sort_array[-1][2] filenames = [x[1] for x in sort_array] filenames = [x[1] for x in sort_array if x[2]>= date] _, image_list, _, visible_num = get_recent_images(1, 0, filenames) @@ -334,6 +334,6 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): with gr.Tab(tab): with gr.Blocks(analytics_enabled=False) as images_history_img2img: show_images_history(gr, opts, tab, run_pnginfo, switch_dict) - gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory") #, visible=False) + gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False) return images_history -- cgit v1.2.3 From 60251c9456f5472784862896c2f97e38feb42482 Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Mon, 17 Oct 2022 06:58:42 +0000 Subject: initial prototype by borrowing contracts --- modules/api/api.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++ modules/processing.py | 2 +- modules/shared.py | 2 +- 3 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 modules/api/api.py (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py new file mode 100644 index 00000000..9d7c699d --- /dev/null +++ b/modules/api/api.py @@ -0,0 +1,60 @@ +from modules.api.processing import StableDiffusionProcessingAPI +from modules.processing import StableDiffusionProcessingTxt2Img, process_images +import modules.shared as shared +import uvicorn +from fastapi import FastAPI, Body, APIRouter +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field, Json +import json +import io +import base64 + +app = FastAPI() + +class TextToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: Json + info: Json + + +class Api: + def __init__(self, txt2img, img2img, run_extras, run_pnginfo): + self.router = APIRouter() + app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) + + def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): + print(txt2imgreq) + p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq)) + p.sd_model = shared.sd_model + print(p) + processed = process_images(p) + + b64images = [] + for i in processed.images: + buffer = io.BytesIO() + i.save(buffer, format="png") + b64images.append(base64.b64encode(buffer.getvalue())) + + response = { + "images": b64images, + "info": processed.js(), + "parameters": json.dumps(vars(txt2imgreq)) + } + + + return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) + + + + def img2imgendoint(self): + raise NotImplementedError + + def extrasendoint(self): + raise NotImplementedError + + def pnginfoendoint(self): + raise NotImplementedError + + def launch(self, server_name, port): + app.include_router(self.router) + uvicorn.run(app, host=server_name, port=port) \ No newline at end of file diff --git a/modules/processing.py b/modules/processing.py index deb6125e..4a7c6ccc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -723,4 +723,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): del x devices.torch_gc() - return samples + return samples \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index c2775603..6c6405fd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,7 +74,7 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) - +parser.add_argument("--api", action='store_true', help="use api=True to launch the api instead of the webui") cmd_opts = parser.parse_args() restricted_opts = [ -- cgit v1.2.3 From 9e02812afd10582f00a7fbbfa63c8f9188678e26 Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Mon, 17 Oct 2022 07:02:08 +0000 Subject: pydantic instrumentation --- modules/api/processing.py | 99 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 modules/api/processing.py (limited to 'modules') diff --git a/modules/api/processing.py b/modules/api/processing.py new file mode 100644 index 00000000..459a8f49 --- /dev/null +++ b/modules/api/processing.py @@ -0,0 +1,99 @@ +from inflection import underscore +from typing import Any, Dict, Optional +from pydantic import BaseModel, Field, create_model +from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +import inspect + + +class ModelDef(BaseModel): + """Assistance Class for Pydantic Dynamic Model Generation""" + + field: str + field_alias: str + field_type: Any + field_value: Any + + +class pydanticModelGenerator: + """ + Takes source_data:Dict ( a single instance example of something like a JSON node) and self generates a pythonic data model with Alias to original source field names. This makes it easy to popuate or export to other systems yet handle the data in a pythonic way. + Being a pydantic datamodel all the richness of pydantic data validation is available and these models can easily be used in FastAPI and or a ORM + + It does not process full JSON data structures but takes simple JSON document with basic elements + + Provide a model_name, an example of JSON data and a dict of type overrides + + Example: + + source_data = {'Name': '48 Rainbow Rd', + 'GroupAddressStyle': 'ThreeLevel', + 'LastModified': '2020-12-21T07:02:51.2400232Z', + 'ProjectStart': '2020-12-03T07:36:03.324856Z', + 'Comment': '', + 'CompletionStatus': 'Editing', + 'LastUsedPuid': '955', + 'Guid': '0c85957b-c2ae-4985-9752-b300ab385b36'} + + source_overrides = {'Guid':{'type':uuid.UUID}, + 'LastModified':{'type':datetime }, + 'ProjectStart':{'type':datetime }, + } + source_optionals = {"Comment":True} + + #create Model + model_Project=pydanticModelGenerator( + model_name="Project", + source_data=source_data, + overrides=source_overrides, + optionals=source_optionals).generate_model() + + #create instance using DynamicModel + project_instance=model_Project(**project_info) + + """ + + def __init__( + self, + model_name: str = None, + source_data: str = None, + params: Dict = {}, + overrides: Dict = {}, + optionals: Dict = {}, + ): + def field_type_generator(k, v, overrides, optionals): + print(k, v) + field_type = str if not overrides.get(k) else overrides[k]["type"] + if v is None: + field_type = Any + else: + field_type = type(v) + + return Optional[field_type] + + self._model_name = model_name + self._json_data = source_data + self._model_def = [ + ModelDef( + field=underscore(k), + field_alias=k, + field_type=field_type_generator(k, v, overrides, optionals), + field_value=v + ) + for (k,v) in source_data.items() if k in params + ] + + def generate_model(self): + """ + Creates a pydantic BaseModel + from the json and overrides provided at initialization + """ + fields = { + d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def + } + DynamicModel = create_model(self._model_name, **fields) + DynamicModel.__config__.allow_population_by_field_name = True + return DynamicModel + +StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing", + StableDiffusionProcessing().__dict__, + inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model() \ No newline at end of file -- cgit v1.2.3 From 832b490e5173f78c4d3aa7ca9ca9ac794d140664 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 17 Oct 2022 03:18:41 -0400 Subject: Update processing.py --- modules/api/processing.py | 41 +++++------------------------------------ 1 file changed, 5 insertions(+), 36 deletions(-) (limited to 'modules') diff --git a/modules/api/processing.py b/modules/api/processing.py index 459a8f49..4c3d0bd0 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -16,46 +16,15 @@ class ModelDef(BaseModel): class pydanticModelGenerator: """ - Takes source_data:Dict ( a single instance example of something like a JSON node) and self generates a pythonic data model with Alias to original source field names. This makes it easy to popuate or export to other systems yet handle the data in a pythonic way. - Being a pydantic datamodel all the richness of pydantic data validation is available and these models can easily be used in FastAPI and or a ORM - - It does not process full JSON data structures but takes simple JSON document with basic elements - - Provide a model_name, an example of JSON data and a dict of type overrides - - Example: - - source_data = {'Name': '48 Rainbow Rd', - 'GroupAddressStyle': 'ThreeLevel', - 'LastModified': '2020-12-21T07:02:51.2400232Z', - 'ProjectStart': '2020-12-03T07:36:03.324856Z', - 'Comment': '', - 'CompletionStatus': 'Editing', - 'LastUsedPuid': '955', - 'Guid': '0c85957b-c2ae-4985-9752-b300ab385b36'} - - source_overrides = {'Guid':{'type':uuid.UUID}, - 'LastModified':{'type':datetime }, - 'ProjectStart':{'type':datetime }, - } - source_optionals = {"Comment":True} - - #create Model - model_Project=pydanticModelGenerator( - model_name="Project", - source_data=source_data, - overrides=source_overrides, - optionals=source_optionals).generate_model() - - #create instance using DynamicModel - project_instance=model_Project(**project_info) - + Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: + source_data is a snapshot of the default values produced by the class + params are the names of the actual keys required by __init__ """ def __init__( self, model_name: str = None, - source_data: str = None, + source_data: {} = {}, params: Dict = {}, overrides: Dict = {}, optionals: Dict = {}, @@ -96,4 +65,4 @@ class pydanticModelGenerator: StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing", StableDiffusionProcessing().__dict__, - inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model() \ No newline at end of file + inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model() -- cgit v1.2.3 From 99013ba68a5fe1bde3621632e5539c03562a3ae8 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 17 Oct 2022 03:20:17 -0400 Subject: Update processing.py --- modules/api/processing.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/api/processing.py b/modules/api/processing.py index 4c3d0bd0..e4df93c5 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -30,7 +30,6 @@ class pydanticModelGenerator: optionals: Dict = {}, ): def field_type_generator(k, v, overrides, optionals): - print(k, v) field_type = str if not overrides.get(k) else overrides[k]["type"] if v is None: field_type = Any -- cgit v1.2.3 From 71d42bb44b257f3fb274c3ad5075a195281ff915 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 17 Oct 2022 03:22:19 -0400 Subject: Update api.py --- modules/api/api.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 9d7c699d..4d9619a8 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -23,10 +23,8 @@ class Api: app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): - print(txt2imgreq) p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq)) p.sd_model = shared.sd_model - print(p) processed = process_images(p) b64images = [] @@ -34,13 +32,6 @@ class Api: buffer = io.BytesIO() i.save(buffer, format="png") b64images.append(base64.b64encode(buffer.getvalue())) - - response = { - "images": b64images, - "info": processed.js(), - "parameters": json.dumps(vars(txt2imgreq)) - } - return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) @@ -57,4 +48,4 @@ class Api: def launch(self, server_name, port): app.include_router(self.router) - uvicorn.run(app, host=server_name, port=port) \ No newline at end of file + uvicorn.run(app, host=server_name, port=port) -- cgit v1.2.3 From d42125baf62880854ad06af06c15c23e7e50cca6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 17 Oct 2022 11:50:20 +0300 Subject: add missing requirement for api and fix some typos --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 4d9619a8..fd09d352 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -18,7 +18,7 @@ class TextToImageResponse(BaseModel): class Api: - def __init__(self, txt2img, img2img, run_extras, run_pnginfo): + def __init__(self): self.router = APIRouter() app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) -- cgit v1.2.3 From 8c6a981d5d9ef30381ac2327460285111550acbc Mon Sep 17 00:00:00 2001 From: Michoko Date: Mon, 17 Oct 2022 11:05:05 +0200 Subject: Added dark mode switch Launch the UI in dark mode with the --dark-mode switch --- modules/shared.py | 2 +- modules/ui.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index c2775603..cbf158e4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -69,13 +69,13 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image upload parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) +parser.add_argument("--dark-mode", action='store_true', help="launches the UI in dark mode", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) - cmd_opts = parser.parse_args() restricted_opts = [ "samples_filename_pattern", diff --git a/modules/ui.py b/modules/ui.py index 43dc88fc..a0cd052e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1783,6 +1783,8 @@ for filename in sorted(os.listdir(jsdir)): with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: javascript += f"\n" +if cmd_opts.dark_mode: + javascript += "\n\n" if 'gradio_routes_templates_response' not in globals(): def template_response(*args, **kwargs): -- cgit v1.2.3 From c408a0b41cfffde184cad35b2d97346342947d83 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 17 Oct 2022 22:28:43 +0800 Subject: fix two bug --- modules/images_history.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 1ae168ca..10e5b970 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -181,7 +181,8 @@ def delete_image(delete_num, name, filenames, image_index, visible_num): return new_file_list, 1, visible_num def save_image(file_name): - shutil.copy2(file_name, opts.outdir_save) + if file_name is not None and os.path.exists(file_name): + shutil.copy2(file_name, opts.outdir_save) def get_recent_images(page_index, step, filenames): page_index = int(page_index) @@ -327,7 +328,6 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): opts = sys_opts loads_files_num = int(opts.images_history_num_per_page) num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num) - backup_flag = opts.images_history_backup with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: for tab in ["txt2img", "img2img", "extras", "saved"]: -- cgit v1.2.3 From 2272cf2f35fafd5cd486bfb4ee89df5bbc625b97 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 17 Oct 2022 23:04:42 +0800 Subject: fix two bug --- modules/images_history.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 10e5b970..1c1790a4 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -133,7 +133,7 @@ def archive_images(dir_name, date_to): date = sort_array[loads_num][2] filenames = [x[1] for x in sort_array] else: - date = sort_array[-1][2] + date = None if len(sort_array) == 0 else sort_array[-1][2] filenames = [x[1] for x in sort_array] filenames = [x[1] for x in sort_array if x[2]>= date] _, image_list, _, visible_num = get_recent_images(1, 0, filenames) -- cgit v1.2.3 From 2b5b62e768d892773a7ec1d5e8d8cea23aae1254 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 17 Oct 2022 23:14:03 +0800 Subject: fix two bug --- modules/images_history.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 1c1790a4..20324557 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -44,7 +44,7 @@ def traverse_all_files(curr_path, image_list, all_type=False): return image_list for file in f_list: file = os.path.join(curr_path, file) - if (not all_type) and file[-4:] == ".txt": + if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"): pass elif os.path.isfile(file) and file[-10:].rfind(".") > 0: image_list.append(file) @@ -182,7 +182,7 @@ def delete_image(delete_num, name, filenames, image_index, visible_num): def save_image(file_name): if file_name is not None and os.path.exists(file_name): - shutil.copy2(file_name, opts.outdir_save) + shutil.copy(file_name, opts.outdir_save) def get_recent_images(page_index, step, filenames): page_index = int(page_index) -- cgit v1.2.3 From 665beebc0825a6fad410c8252f27f6f6f0bd900b Mon Sep 17 00:00:00 2001 From: Michoko Date: Mon, 17 Oct 2022 18:24:24 +0200 Subject: Use of a --theme argument for more flexibility Added possibility to set the theme (light or dark) --- modules/shared.py | 2 +- modules/ui.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index cbf158e4..fa084c69 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -69,7 +69,7 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image upload parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) -parser.add_argument("--dark-mode", action='store_true', help="launches the UI in dark mode", default=False) +parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) diff --git a/modules/ui.py b/modules/ui.py index a0cd052e..d41715fa 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1783,8 +1783,8 @@ for filename in sorted(os.listdir(jsdir)): with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: javascript += f"\n" -if cmd_opts.dark_mode: - javascript += "\n\n" +if cmd_opts.theme is not None: + javascript += f"\n\n" if 'gradio_routes_templates_response' not in globals(): def template_response(*args, **kwargs): -- cgit v1.2.3 From 695377a8b9f7de28b880d96487a9ddf7230cff14 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 17 Oct 2022 19:56:23 +0300 Subject: make modelmerger work with ui-config.json --- modules/ui.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 43dc88fc..533b1db3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1767,6 +1767,7 @@ Requested path was: {f} visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") visit(extras_interface, loadsave, "extras") + visit(modelmerger_interface, loadsave, "modelmerger") if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): with open(ui_config_file, "w", encoding="utf8") as file: -- cgit v1.2.3 From cf47d13c1e11fcb7169bac7488d2c39e579ee491 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 17 Oct 2022 21:15:32 +0300 Subject: localization support --- modules/localization.py | 31 +++++++++++++++++++++++++++++++ modules/shared.py | 7 +++++-- modules/ui.py | 33 +++++++++++++++++++++++---------- 3 files changed, 59 insertions(+), 12 deletions(-) create mode 100644 modules/localization.py (limited to 'modules') diff --git a/modules/localization.py b/modules/localization.py new file mode 100644 index 00000000..b1810cda --- /dev/null +++ b/modules/localization.py @@ -0,0 +1,31 @@ +import json +import os +import sys +import traceback + +localizations = {} + + +def list_localizations(dirname): + localizations.clear() + + for file in os.listdir(dirname): + fn, ext = os.path.splitext(file) + if ext.lower() != ".json": + continue + + localizations[fn] = os.path.join(dirname, file) + + +def localization_js(current_localization_name): + fn = localizations.get(current_localization_name, None) + data = {} + if fn is not None: + try: + with open(fn, "r", encoding="utf8") as file: + data = json.load(file) + except Exception: + print(f"Error loading localization from {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + return f"var localization = {json.dumps(data)}\n" diff --git a/modules/shared.py b/modules/shared.py index c2775603..2a2b0427 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models +from modules import sd_samplers, sd_models, localization from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -31,6 +31,7 @@ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") +parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") @@ -103,7 +104,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None - def reload_hypernetworks(): global hypernetworks @@ -151,6 +151,8 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] +localization.list_localizations(cmd_opts.localizations_dir) + def realesrgan_models_names(): import modules.realesrgan_model @@ -296,6 +298,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), + 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { diff --git a/modules/ui.py b/modules/ui.py index 533b1db3..656bab7a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,7 +23,7 @@ import gradio as gr import gradio.utils import gradio.routes -from modules import sd_hijack, sd_models +from modules import sd_hijack, sd_models, localization from modules.paths import script_path from modules.shared import opts, cmd_opts, restricted_opts if cmd_opts.deepdanbooru: @@ -1056,10 +1056,10 @@ def create_ui(wrap_gradio_gpu_call): upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2 = gr.Radio(label='Upscaler 2', celem_id="extras_upscaler_2", hoices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) with gr.Group(): @@ -1224,10 +1224,10 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tab(label="Train"): gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") with gr.Row(): - train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") with gr.Row(): - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") batch_size = gr.Number(label='Batch size', value=1, precision=0) @@ -1376,16 +1376,18 @@ def create_ui(wrap_gradio_gpu_call): else: raise Exception(f'bad options item type: {str(t)} for key {key}') + elem_id = "setting_"+key + if info.refresh is not None: if is_quicksettings: - res = comp(label=info.label, value=fun, **(args or {})) - refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: with gr.Row(variant="compact"): - res = comp(label=info.label, value=fun, **(args or {})) - refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: - res = comp(label=info.label, value=fun, **(args or {})) + res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) return res @@ -1509,6 +1511,9 @@ Requested path was: {f} with gr.Row(): request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + + with gr.Row(): reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') @@ -1519,6 +1524,13 @@ Requested path was: {f} _js='function(){}' ) + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + def reload_scripts(): modules.scripts.reload_script_body_only() @@ -1784,6 +1796,7 @@ for filename in sorted(os.listdir(jsdir)): with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: javascript += f"\n" +javascript += f"\n" if 'gradio_routes_templates_response' not in globals(): def template_response(*args, **kwargs): -- cgit v1.2.3 From d62ef76614624cda99d842a2900242d5b7923eda Mon Sep 17 00:00:00 2001 From: guaneec Date: Tue, 18 Oct 2022 03:09:50 +0800 Subject: Don't eat colons in booru tags --- modules/deepbooru.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 4ad334a1..de16b13f 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -157,8 +157,6 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o # sort by reverse by likelihood and normal for alpha, and format tag text as requested unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) for weight, tag in unsorted_tags_in_theshold: - # note: tag_outformat will still have a colon if include_ranks is True - tag_outformat = tag.replace(':', ' ') if use_spaces: tag_outformat = tag_outformat.replace('_', ' ') if use_escape: -- cgit v1.2.3 From f80e914ac4aa69a9783b4040813253500b34d925 Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Mon, 17 Oct 2022 19:10:36 +0000 Subject: example API working with gradio --- modules/api/api.py | 9 ++++++-- modules/api/processing.py | 56 ++++++++++++++++++++++++++++++++--------------- modules/processing.py | 22 +++++++++++++------ 3 files changed, 60 insertions(+), 27 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index fd09d352..5e86c3bf 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -23,8 +23,13 @@ class Api: app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): - p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq)) - p.sd_model = shared.sd_model + populate = txt2imgreq.copy(update={ # Override __init__ params + "sd_model": shared.sd_model, + "sampler_index": 0, + } + ) + p = StableDiffusionProcessingTxt2Img(**vars(populate)) + # Override object param processed = process_images(p) b64images = [] diff --git a/modules/api/processing.py b/modules/api/processing.py index e4df93c5..b6798241 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -5,6 +5,24 @@ from modules.processing import StableDiffusionProcessing, Processed, StableDiffu import inspect +API_NOT_ALLOWED = [ + "self", + "kwargs", + "sd_model", + "outpath_samples", + "outpath_grids", + "sampler_index", + "do_not_save_samples", + "do_not_save_grid", + "extra_generation_params", + "overlay_images", + "do_not_reload_embeddings", + "seed_enable_extras", + "prompt_for_display", + "sampler_noise_scheduler_override", + "ddim_discretize" +] + class ModelDef(BaseModel): """Assistance Class for Pydantic Dynamic Model Generation""" @@ -14,7 +32,7 @@ class ModelDef(BaseModel): field_value: Any -class pydanticModelGenerator: +class PydanticModelGenerator: """ Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: source_data is a snapshot of the default values produced by the class @@ -24,30 +42,33 @@ class pydanticModelGenerator: def __init__( self, model_name: str = None, - source_data: {} = {}, - params: Dict = {}, - overrides: Dict = {}, - optionals: Dict = {}, + class_instance = None ): - def field_type_generator(k, v, overrides, optionals): - field_type = str if not overrides.get(k) else overrides[k]["type"] - if v is None: - field_type = Any - else: - field_type = type(v) + def field_type_generator(k, v): + # field_type = str if not overrides.get(k) else overrides[k]["type"] + # print(k, v.annotation, v.default) + field_type = v.annotation return Optional[field_type] + def merge_class_params(class_): + all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) + parameters = {} + for classes in all_classes: + parameters = {**parameters, **inspect.signature(classes.__init__).parameters} + return parameters + + self._model_name = model_name - self._json_data = source_data + self._class_data = merge_class_params(class_instance) self._model_def = [ ModelDef( field=underscore(k), field_alias=k, - field_type=field_type_generator(k, v, overrides, optionals), - field_value=v + field_type=field_type_generator(k, v), + field_value=v.default ) - for (k,v) in source_data.items() if k in params + for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED ] def generate_model(self): @@ -60,8 +81,7 @@ class pydanticModelGenerator: } DynamicModel = create_model(self._model_name, **fields) DynamicModel.__config__.allow_population_by_field_name = True + DynamicModel.__config__.allow_mutation = True return DynamicModel -StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing", - StableDiffusionProcessing().__dict__, - inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model() +StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model() diff --git a/modules/processing.py b/modules/processing.py index 4a7c6ccc..024a4fc3 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -9,6 +9,7 @@ from PIL import Image, ImageFilter, ImageOps import random import cv2 from skimage import exposure +from typing import Any, Dict, List, Optional import modules.sd_hijack from modules import devices, prompt_parser, masking, sd_samplers, lowvram @@ -51,9 +52,15 @@ def get_correct_sampler(p): return sd_samplers.samplers elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): return sd_samplers.samplers_for_img2img + elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI): + return sd_samplers.samplers -class StableDiffusionProcessing: - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None, do_not_reload_embeddings=False): +class StableDiffusionProcessing(): + """ + The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing + + """ + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0): self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids @@ -86,10 +93,10 @@ class StableDiffusionProcessing: self.denoising_strength: float = 0 self.sampler_noise_scheduler_override = None self.ddim_discretize = opts.ddim_discretize - self.s_churn = opts.s_churn - self.s_tmin = opts.s_tmin - self.s_tmax = float('inf') # not representable as a standard ui option - self.s_noise = opts.s_noise + self.s_churn = s_churn or opts.s_churn + self.s_tmin = s_tmin or opts.s_tmin + self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option + self.s_noise = s_noise or opts.s_noise if not seed_enable_extras: self.subseed = -1 @@ -97,6 +104,7 @@ class StableDiffusionProcessing: self.seed_resize_from_h = 0 self.seed_resize_from_w = 0 + def init(self, all_prompts, all_seeds, all_subseeds): pass @@ -497,7 +505,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs): + def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength -- cgit v1.2.3 From 2e28c841f438b2090caac2b9a54eb62ddbda837c Mon Sep 17 00:00:00 2001 From: guaneec Date: Tue, 18 Oct 2022 03:15:41 +0800 Subject: Oops --- modules/deepbooru.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index de16b13f..8914662d 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -157,6 +157,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o # sort by reverse by likelihood and normal for alpha, and format tag text as requested unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) for weight, tag in unsorted_tags_in_theshold: + tag_outformat = tag if use_spaces: tag_outformat = tag_outformat.replace('_', ' ') if use_escape: -- cgit v1.2.3 From f29b16bad19b6332a15b2ef439864d866277fffb Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Mon, 17 Oct 2022 20:36:14 +0000 Subject: prevent API from saving --- modules/api/api.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 5e86c3bf..ce72c5ee 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -26,6 +26,8 @@ class Api: populate = txt2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, "sampler_index": 0, + "do_not_save_samples": True, + "do_not_save_grid": True } ) p = StableDiffusionProcessingTxt2Img(**vars(populate)) -- cgit v1.2.3 From ab3f997c0c4a1423a82623ae1d4d3c66005bb8da Mon Sep 17 00:00:00 2001 From: Jordan Hall Date: Mon, 17 Oct 2022 20:59:44 +0100 Subject: Fix typo in 'choices' when loading upscaler 2 config --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 656bab7a..e4ead347 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1059,7 +1059,7 @@ def create_ui(wrap_gradio_gpu_call): extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', celem_id="extras_upscaler_2", hoices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2 = gr.Radio(label='Upscaler 2', celem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) with gr.Group(): -- cgit v1.2.3 From d3338bdef18b3049431a0649d55ff22aa18baa68 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 17 Oct 2022 22:46:56 +0100 Subject: extras extend cache key with new upscale to options --- modules/extras.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 8dbab240..c908b43e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -91,7 +91,8 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) pixels = tuple(np.array(small).flatten().tolist()) - key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels + key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight, + resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels c = cached_images.get(key) if c is None: -- cgit v1.2.3 From c3851a853d99ad35ccedcdd8dbeb6cfbe273439b Mon Sep 17 00:00:00 2001 From: Ryan Voots Date: Mon, 17 Oct 2022 12:49:33 -0400 Subject: Re-use webui fastapi application rather than requiring one or the other, not both. --- modules/api/api.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index ce72c5ee..8781cd86 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -2,15 +2,13 @@ from modules.api.processing import StableDiffusionProcessingAPI from modules.processing import StableDiffusionProcessingTxt2Img, process_images import modules.shared as shared import uvicorn -from fastapi import FastAPI, Body, APIRouter +from fastapi import Body, APIRouter from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, Json import json import io import base64 -app = FastAPI() - class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: Json @@ -18,7 +16,7 @@ class TextToImageResponse(BaseModel): class Api: - def __init__(self): + def __init__(self, app): self.router = APIRouter() app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) -- cgit v1.2.3 From 247aeb3aaaf2925c7d68a9cf47c975f3e6d3dd33 Mon Sep 17 00:00:00 2001 From: Ryan Voots Date: Mon, 17 Oct 2022 12:50:45 -0400 Subject: Put API under /sdapi/ so that routing is simpler in the future. This means that one could allow access to /sdapi/ but not the webui. --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 8781cd86..14613d8c 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -18,7 +18,7 @@ class TextToImageResponse(BaseModel): class Api: def __init__(self, app): self.router = APIRouter() - app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) + app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): populate = txt2imgreq.copy(update={ # Override __init__ params -- cgit v1.2.3 From 1df3ff25e6fe2e3f308e45f7a6dd37fb4f1988e6 Mon Sep 17 00:00:00 2001 From: Ryan Voots Date: Mon, 17 Oct 2022 12:58:34 -0400 Subject: Add --nowebui as a means of disabling the webui and run on the other port --- modules/shared.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 6c6405fd..8b436970 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,7 +74,8 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) -parser.add_argument("--api", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") +parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") cmd_opts = parser.parse_args() restricted_opts = [ -- cgit v1.2.3 From 7432b6f4d2c3001895fc75411a34afae1810c1a2 Mon Sep 17 00:00:00 2001 From: Mykeehu Date: Tue, 18 Oct 2022 07:15:38 +0200 Subject: Fix typo "celem_id" to "elem_id" --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index e4ead347..2a7f64f9 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1059,7 +1059,7 @@ def create_ui(wrap_gradio_gpu_call): extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', celem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) with gr.Group(): -- cgit v1.2.3 From 8d5d863a9d11850464fdb6b64f34602803c15ccc Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Tue, 18 Oct 2022 06:51:53 +0000 Subject: gradio and FastAPI --- modules/api/api.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 14613d8c..ce98cb8c 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -16,9 +16,11 @@ class TextToImageResponse(BaseModel): class Api: - def __init__(self, app): + def __init__(self, app, queue_lock): self.router = APIRouter() - app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) + self.app = app + self.queue_lock = queue_lock + self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): populate = txt2imgreq.copy(update={ # Override __init__ params @@ -30,7 +32,8 @@ class Api: ) p = StableDiffusionProcessingTxt2Img(**vars(populate)) # Override object param - processed = process_images(p) + with self.queue_lock: + processed = process_images(p) b64images = [] for i in processed.images: @@ -52,5 +55,5 @@ class Api: raise NotImplementedError def launch(self, server_name, port): - app.include_router(self.router) - uvicorn.run(app, host=server_name, port=port) + self.app.include_router(self.router) + uvicorn.run(self.app, host=server_name, port=port) -- cgit v1.2.3 From 786ed499226177d71e937e0342bcb9d3b1ff260f Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 17 Oct 2022 19:48:39 +0300 Subject: use legacy attnblock --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 984b35c4..2407a461 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -27,7 +27,7 @@ def apply_optimizations(): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 -- cgit v1.2.3 From 2043c4a231eef838bb15044f502b864b55885037 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 17 Oct 2022 19:49:11 +0300 Subject: delete xformers attnblock --- modules/sd_hijack_optimizations.py | 12 ------------ 1 file changed, 12 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 79405525..60da7459 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -292,15 +292,3 @@ def cross_attention_attnblock_forward(self, x): return h3 -def xformers_attnblock_forward(self, x): - try: - h_ = x - h_ = self.norm(h_) - q1 = self.q(h_).contiguous() - k1 = self.k(h_).contiguous() - v = self.v(h_).contiguous() - out = xformers.ops.memory_efficient_attention(q1, k1, v) - out = self.proj_out(out) - return x + out - except NotImplementedError: - return cross_attention_attnblock_forward(self, x) -- cgit v1.2.3 From 84823275e896bcc1f7cb4ce098ae3c5d05e17b9a Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 17 Oct 2022 22:18:59 +0300 Subject: readd xformers attnblock --- modules/sd_hijack_optimizations.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 60da7459..7ebef3f0 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -292,3 +292,18 @@ def cross_attention_attnblock_forward(self, x): return h3 +def xformers_attnblock_forward(self, x): + try: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + out = xformers.ops.memory_efficient_attention(q, k, v) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out + except NotImplementedError: + return cross_attention_attnblock_forward(self, x) -- cgit v1.2.3 From 73b5dbf72a93b64445551c74a4c0dc924986081d Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 17 Oct 2022 22:19:18 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2407a461..984b35c4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -27,7 +27,7 @@ def apply_optimizations(): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 -- cgit v1.2.3 From c71008c74156635558bb2e877d1628913f6f781e Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Tue, 18 Oct 2022 00:02:50 +0300 Subject: Update sd_hijack_optimizations.py --- modules/sd_hijack_optimizations.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 7ebef3f0..a3345bb9 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -301,6 +301,9 @@ def xformers_attnblock_forward(self, x): v = self.v(h_) b, c, h, w = q.shape q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() out = xformers.ops.memory_efficient_attention(q, k, v) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) -- cgit v1.2.3 From 8b02662215917d39f76f86b703a322818d5a8ad4 Mon Sep 17 00:00:00 2001 From: trufty Date: Mon, 17 Oct 2022 10:58:21 -0400 Subject: Disable auto weights swap with config option --- modules/shared.py | 1 + modules/ui.py | 4 ++++ 2 files changed, 5 insertions(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 9603d26e..8a1d1881 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -266,6 +266,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), + "disable_weights_auto_swap": OptionInfo(False, "Disable auto swapping weights to match model hash in prompts"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), diff --git a/modules/ui.py b/modules/ui.py index 1dae4a65..75eb0b0c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -542,6 +542,10 @@ def apply_setting(key, value): if value is None: return gr.update() + # dont allow model to be swapped when model hash exists in prompt + if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: + return gr.update() + if key == "sd_model_checkpoint": ckpt_info = sd_models.get_closet_checkpoint_match(value) -- cgit v1.2.3 From d2f459c5cf9f728256775dc1c3380c7e9a7e27fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 18 Oct 2022 14:22:52 +0300 Subject: clarify the comment for the new option from #2959 and move it to UI section. --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 8a1d1881..c0d87168 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -266,7 +266,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), - "disable_weights_auto_swap": OptionInfo(False, "Disable auto swapping weights to match model hash in prompts"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), @@ -294,6 +293,7 @@ options_templates.update(options_section(('ui', "User interface"), { "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), + "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), -- cgit v1.2.3 From 97d3ba3941536215ea15431886c7f28300a9d915 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=B5=E3=81=81?= <34892635+fa0311@users.noreply.github.com> Date: Tue, 18 Oct 2022 17:29:42 +0900 Subject: Add scripts to ui-config,json --- modules/scripts.py | 15 +++++++++++++-- modules/ui.py | 5 +++++ 2 files changed, 18 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index ac66d448..3402066d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -96,6 +96,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs): class ScriptRunner: def __init__(self): self.scripts = [] + self.titles = [] def setup_ui(self, is_img2img): for script_class, path in scripts_data: @@ -107,9 +108,10 @@ class ScriptRunner: self.scripts.append(script) - titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] + self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] - dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index") + dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") + dropdown.save_to_config = True inputs = [dropdown] for script in self.scripts: @@ -139,6 +141,15 @@ class ScriptRunner: return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] + def init_field(title): + if title == "None": + return + script_index = self.titles.index(title) + script = self.scripts[script_index] + for i in range(script.args_from, script.args_to): + inputs[i].visible = True + + dropdown.init_field = init_field dropdown.change( fn=select_script, inputs=[dropdown], diff --git a/modules/ui.py b/modules/ui.py index 75eb0b0c..39afbc4e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1753,6 +1753,11 @@ Requested path was: {f} print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') else: setattr(obj, field, saved_value) + if getattr(x, 'init_field', False): + try: + x.init_field(saved_value) + except Exception: + print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible: apply_field(x, 'visible') -- cgit v1.2.3 From de29ec0743fcfb141d8891a3ccbd537ea71bf5b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=B5=E3=81=81?= <34892635+fa0311@users.noreply.github.com> Date: Tue, 18 Oct 2022 18:15:00 +0900 Subject: Remove exception handling --- modules/ui.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 39afbc4e..b38bfb3f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1754,10 +1754,7 @@ Requested path was: {f} else: setattr(obj, field, saved_value) if getattr(x, 'init_field', False): - try: - x.init_field(saved_value) - except Exception: - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + x.init_field(saved_value) if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible: apply_field(x, 'visible') -- cgit v1.2.3 From 3003438088502774628656790d83fc8074d51ab4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=B5=E3=81=81?= <34892635+fa0311@users.noreply.github.com> Date: Tue, 18 Oct 2022 18:51:57 +0900 Subject: Add visible for dropdown --- modules/ui.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index b38bfb3f..fb6eb5a0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1737,7 +1737,7 @@ Requested path was: {f} print(traceback.format_exc(), file=sys.stderr) def loadsave(path, x): - def apply_field(obj, field, condition=None): + def apply_field(obj, field, condition=None, init_field=None): key = path + "/" + field if getattr(obj,'custom_script_source',None) is not None: @@ -1753,8 +1753,8 @@ Requested path was: {f} print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') else: setattr(obj, field, saved_value) - if getattr(x, 'init_field', False): - x.init_field(saved_value) + if init_field is not None: + init_field(saved_value) if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible: apply_field(x, 'visible') @@ -1780,7 +1780,8 @@ Requested path was: {f} # Since there are many dropdowns that shouldn't be saved, # we only mark dropdowns that should be saved. if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False): - apply_field(x, 'value', lambda val: val in x.choices) + apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) + apply_field(x, 'visible') visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") -- cgit v1.2.3 From 02622b19191f5f5112db7633c0630e5c7df1b2f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=B5=E3=81=81?= <34892635+fa0311@users.noreply.github.com> Date: Tue, 18 Oct 2022 18:52:27 +0900 Subject: update scripts.py --- modules/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 3402066d..1039fa9c 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -142,7 +142,7 @@ class ScriptRunner: return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] def init_field(title): - if title == "None": + if title == 'None': return script_index = self.titles.index(title) script = self.scripts[script_index] -- cgit v1.2.3 From 4c605c5174a9b211c3a88e9aff5f5be92b53fd92 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 16 Oct 2022 17:24:06 +0100 Subject: add shared option for update check --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index c0d87168..50dc46ae 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -76,6 +76,7 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) +parser.add_argument("--update-check", action='store_true', help="enable http check to confirm that the currently running version is the most recent release.", default=False) cmd_opts = parser.parse_args() restricted_opts = [ -- cgit v1.2.3 From eb299527b1e5d1f83a14641647fca72e8fb305ac Mon Sep 17 00:00:00 2001 From: yfszzx Date: Tue, 18 Oct 2022 20:14:11 +0800 Subject: Image browser --- modules/images_history.py | 227 ++++++++++++++++++++++++++++++---------------- modules/shared.py | 7 +- modules/ui.py | 2 +- 3 files changed, 154 insertions(+), 82 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 20324557..d56f3a25 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -4,6 +4,7 @@ import time import hashlib import gradio system_bak_path = "webui_log_and_bak" +browser_tabname = "custom" def is_valid_date(date): try: time.strptime(date, "%Y%m%d") @@ -99,13 +100,15 @@ def auto_sorting(dir_name): date_list.append(today) return sorted(date_list, reverse=True) -def archive_images(dir_name, date_to): +def archive_images(dir_name, date_to): + filenames = [] loads_num =int(opts.images_history_num_per_page * opts.images_history_pages_num) + today = time.strftime("%Y%m%d",time.localtime(time.time())) + date_to = today if date_to is None or date_to == "" else date_to + date_to_bak = date_to if opts.images_history_reconstruct_directory: - date_list = auto_sorting(dir_name) - today = time.strftime("%Y%m%d",time.localtime(time.time())) - date_to = today if date_to is None or date_to == "" else date_to + date_list = auto_sorting(dir_name) for date in date_list: if date <= date_to: path = os.path.join(dir_name, date) @@ -120,7 +123,7 @@ def archive_images(dir_name, date_to): tmparray = [(os.path.getmtime(file), file) for file in filenames ] date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400 filenames = [] - date_list = {} + date_list = {date_to:None} date = time.strftime("%Y%m%d",time.localtime(time.time())) for t, f in tmparray: date = time.strftime("%Y%m%d",time.localtime(t)) @@ -133,22 +136,29 @@ def archive_images(dir_name, date_to): date = sort_array[loads_num][2] filenames = [x[1] for x in sort_array] else: - date = None if len(sort_array) == 0 else sort_array[-1][2] + date = date_to if len(sort_array) == 0 else sort_array[-1][2] filenames = [x[1] for x in sort_array] - filenames = [x[1] for x in sort_array if x[2]>= date] - _, image_list, _, visible_num = get_recent_images(1, 0, filenames) + filenames = [x[1] for x in sort_array if x[2]>= date] + num = len(filenames) + last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000)) + date = date[:4] + "-" + date[4:6] + "-" + date[6:8] + date_to_bak = date_to_bak[:4] + "-" + date_to_bak[4:6] + "-" + date_to_bak[6:8] + load_info = f"Loaded {(num + 1) // opts.images_history_pages_num} pades, {num} images, during {date} - {date_to_bak}" + _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames) return ( gradio.Dropdown.update(choices=date_list, value=date_to), - date, + load_info, filenames, 1, image_list, "", - visible_num + "", + visible_num, + last_date_from ) -def newest_click(dir_name, date_to): - return archive_images(dir_name, time.strftime("%Y%m%d",time.localtime(time.time()))) + + def delete_image(delete_num, name, filenames, image_index, visible_num): if name == "": @@ -196,7 +206,29 @@ def get_recent_images(page_index, step, filenames): length = len(filenames) visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num - return page_index, image_list, "", visible_num + return page_index, image_list, "", "", visible_num + +def newest_click(date_to): + if date_to is None: + return time.strftime("%Y%m%d",time.localtime(time.time())), [] + else: + return None, [] +def forward_click(last_date_from, date_to_recorder): + if len(date_to_recorder) == 0: + return None, [] + if last_date_from == date_to_recorder[-1]: + date_to_recorder = date_to_recorder[:-1] + if len(date_to_recorder) == 0: + return None, [] + return date_to_recorder[-1], date_to_recorder[:-1] + +def backward_click(last_date_from, date_to_recorder): + if last_date_from is None or last_date_from == "": + return time.strftime("%Y%m%d",time.localtime(time.time())), [] + if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]: + date_to_recorder.append(last_date_from) + return last_date_from, date_to_recorder + def first_page_click(page_index, filenames): return get_recent_images(1, 0, filenames) @@ -214,13 +246,33 @@ def page_index_change(page_index, filenames): return get_recent_images(page_index, 0, filenames) def show_image_info(tabname_box, num, page_index, filenames): - file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] - return file, num, file + file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] + tm = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + return file, tm, num, file def enable_page_buttons(): return gradio.update(visible=True) +def change_dir(img_dir, date_to): + warning = None + try: + if os.path.exists(img_dir): + try: + f = os.listdir(img_dir) + except: + warning = f"'{img_dir} is not a directory" + else: + warning = "The directory is not exist" + except: + warning = "The format of the directory is incorrect" + if warning is None: + today = time.strftime("%Y%m%d",time.localtime(time.time())) + return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today + else: + return gradio.update(visible=True), gradio.update(visible=False), warning, date_to + def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): + custom_dir = False if tabname == "txt2img": dir_name = opts.outdir_txt2img_samples elif tabname == "img2img": @@ -229,69 +281,85 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): dir_name = opts.outdir_extras_samples elif tabname == "saved": dir_name = opts.outdir_save + else: + custom_dir = True + dir_name = None + + if not custom_dir: + d = dir_name.split("/") + dir_name = d[0] + for p in d[1:]: + dir_name = os.path.join(dir_name, p) + if not os.path.exists(dir_name): + os.makedirs(dir_name) - d = dir_name.split("/") - dir_name = d[0] - for p in d[1:]: - dir_name = os.path.join(dir_name, p) - if not os.path.exists(dir_name): - os.makedirs(dir_name) - - with gr.Column() as page_panel: - with gr.Row(visible=False) as turn_page_buttons: - renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") - first_page = gr.Button('First Page') - prev_page = gr.Button('Prev Page') - page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next Page') - end_page = gr.Button('End Page') - - with gr.Row(elem_id=tabname + "_images_history"): - with gr.Column(scale=2): - with gr.Row(): - newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") - date_from = gr.Textbox(label="Date from", interactive=False) - date_to = gr.Dropdown(label="Date to") - - history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) - with gr.Row(): - delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") - delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - - with gr.Column(): - with gr.Row(): - if tabname != "saved": - save_btn = gr.Button('Save') - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') - with gr.Row(): - with gr.Column(): - img_file_info = gr.Textbox(label="Generate Info", interactive=False) - img_file_name = gr.Textbox(value="", label="File Name", interactive=False) + with gr.Column() as page_panel: + with gr.Row(): + img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory") + with gr.Row(visible=False) as warning: + warning_box = gr.Textbox("Message", interactive=False) + with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel: + with gr.Column(scale=2): + with gr.Row(): + backward = gr.Button('Backward') + date_to = gr.Dropdown(label="Date to") + forward = gr.Button('Forward') + newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") + with gr.Row(): + load_info = gr.Textbox(show_label=False, interactive=False) + with gr.Row(visible=False) as turn_page_buttons: + renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") + first_page = gr.Button('First Page') + prev_page = gr.Button('Prev Page') + page_index = gr.Number(value=1, label="Page Index") + next_page = gr.Button('Next Page') + end_page = gr.Button('End Page') + + history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num) + with gr.Row(): + delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") + delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - # hiden items - with gr.Row(visible=False): - visible_img_num = gr.Number() - img_path = gr.Textbox(dir_name) - tabname_box = gr.Textbox(tabname) - image_index = gr.Textbox(value=-1) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") - filenames = gr.State() - all_images_list = gr.State() - hidden = gr.Image(type="pil") - info1 = gr.Textbox() - info2 = gr.Textbox() + with gr.Column(): + with gr.Row(): + if tabname != "saved": + save_btn = gr.Button('Save') + pnginfo_send_to_txt2img = gr.Button('Send to txt2img') + pnginfo_send_to_img2img = gr.Button('Send to img2img') + with gr.Row(): + with gr.Column(): + img_file_info = gr.Textbox(label="Generate Info", interactive=False) + img_file_name = gr.Textbox(value="", label="File Name", interactive=False) + img_file_time= gr.Textbox(value="", label="Create Time", interactive=False) - + + # hiden items + with gr.Row(): #visible=False): + visible_img_num = gr.Number() + date_to_recorder = gr.State([]) + last_date_from = gr.Textbox() + tabname_box = gr.Textbox(tabname) + image_index = gr.Textbox(value=-1) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") + filenames = gr.State() + all_images_list = gr.State() + hidden = gr.Image(type="pil") + info1 = gr.Textbox() + info2 = gr.Textbox() + + img_path.submit(change_dir, inputs=[img_path, date_to], outputs=[warning, main_panel, warning_box, date_to]) #change date - change_date_output = [date_to, date_from, filenames, page_index, history_gallery, img_file_name, visible_img_num] - newest.click(newest_click, inputs=[img_path, date_to], outputs=change_date_output) - date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output) - newest.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) - newest.click(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) + change_date_output = [date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from] + + date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output) + date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) + date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + + newest.click(newest_click, inputs=[date_to], outputs=[date_to, date_to_recorder]) + forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder]) + backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder]) + #delete delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num]) @@ -301,7 +369,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): #turn page gallery_inputs = [page_index, filenames] - gallery_outputs = [page_index, history_gallery, img_file_name, visible_img_num] + gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num] first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs) next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs) prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) @@ -317,12 +385,14 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") # other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden]) + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden]) img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') + + def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): global opts; opts = sys_opts @@ -330,10 +400,11 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num) with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: - for tab in ["txt2img", "img2img", "extras", "saved"]: + for tab in [browser_tabname, "txt2img", "img2img", "extras", "saved"]: with gr.Tab(tab): - with gr.Blocks(analytics_enabled=False) as images_history_img2img: + with gr.Blocks(analytics_enabled=False) : show_images_history(gr, opts, tab, run_pnginfo, switch_dict) - gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False) + #gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False) + gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_finish_render", visible=False) return images_history diff --git a/modules/shared.py b/modules/shared.py index c2ea4186..1811018d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -309,10 +309,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) -options_templates.update(options_section(('images-history', "Images history"), { - "images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), +options_templates.update(options_section(('images-history', "Images Browser"), { + #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"), - "images_history_pages_num": OptionInfo(6, "Maximum number of pages per load "), + "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "), + "images_history_grid_num": OptionInfo(6, "Number of grids in each row"), })) diff --git a/modules/ui.py b/modules/ui.py index 43dc88fc..85abac4d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1548,7 +1548,7 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (images_history, "History", "images_history"), + (images_history, "Image Browser", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), (settings_interface, "Settings", "settings"), -- cgit v1.2.3 From 433a7525c1f5eb5963340e0cc45d31038ede3f7e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 18 Oct 2022 15:18:02 +0300 Subject: remove shared option for update check (because it is not an argument of webui) have launch.py examine both COMMANDLINE_ARGS as well as argv for its arguments --- modules/shared.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 50dc46ae..c0d87168 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -76,7 +76,6 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) -parser.add_argument("--update-check", action='store_true', help="enable http check to confirm that the currently running version is the most recent release.", default=False) cmd_opts = parser.parse_args() restricted_opts = [ -- cgit v1.2.3 From 2f448d97a9427f9a7bad19cf608561b2878ab2da Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 17 Oct 2022 23:18:21 +0900 Subject: styles.csv encoding utf8 to utf-8-sig utf-8-bom for better compatibility for some programs --- modules/styles.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/styles.py b/modules/styles.py index d44dfc1a..3bf5c5b6 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -45,7 +45,7 @@ class StyleDatabase: if not os.path.exists(path): return - with open(path, "r", encoding="utf8", newline='') as file: + with open(path, "r", encoding="utf-8-sig", newline='') as file: reader = csv.DictReader(file) for row in reader: # Support loading old CSV format with "name, text"-columns @@ -79,7 +79,7 @@ class StyleDatabase: def save_styles(self, path: str) -> None: # Write to temporary file first, so we don't nuke the file if something goes wrong fd, temp_path = tempfile.mkstemp(".csv") - with os.fdopen(fd, "w", encoding="utf8", newline='') as file: + with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) -- cgit v1.2.3 From e20b7e30fe17744acb74ad33c87c0963525ea921 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 18 Oct 2022 15:33:24 +0300 Subject: fix for add difference model merging --- modules/extras.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index c908b43e..03f6085e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -216,8 +216,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam if theta_func1: for key in tqdm.tqdm(theta_1.keys()): if 'model' in key: - t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) - theta_1[key] = theta_func1(theta_1[key], t2) + if key in theta_2: + t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) + theta_1[key] = theta_func1(theta_1[key], t2) + else: + theta_1[key] = 0 del theta_2, teritary_model for key in tqdm.tqdm(theta_0.keys()): -- cgit v1.2.3 From ec1924ee5789b72c31c65932b549c59ccae0cdd6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 18 Oct 2022 16:05:52 +0300 Subject: additional fix for difference model merging --- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 03f6085e..b853fa5b 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -220,7 +220,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) theta_1[key] = theta_func1(theta_1[key], t2) else: - theta_1[key] = 0 + theta_1[key] = torch.zeros_like(theta_1[key]) del theta_2, teritary_model for key in tqdm.tqdm(theta_0.keys()): -- cgit v1.2.3 From b7e78ef692fe912916de6e54f6e2521b000d650c Mon Sep 17 00:00:00 2001 From: yfszzx Date: Tue, 18 Oct 2022 22:21:54 +0800 Subject: Image browser improve --- modules/images_history.py | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index d56f3a25..a40cdc0e 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -100,14 +100,15 @@ def auto_sorting(dir_name): date_list.append(today) return sorted(date_list, reverse=True) -def archive_images(dir_name, date_to): - +def archive_images(dir_name, date_to): filenames = [] - loads_num =int(opts.images_history_num_per_page * opts.images_history_pages_num) + batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num) + if batch_size <= 0: + batch_size = opts.images_history_num_per_page * 6 today = time.strftime("%Y%m%d",time.localtime(time.time())) date_to = today if date_to is None or date_to == "" else date_to date_to_bak = date_to - if opts.images_history_reconstruct_directory: + if False: #opts.images_history_reconstruct_directory: date_list = auto_sorting(dir_name) for date in date_list: if date <= date_to: @@ -115,11 +116,13 @@ def archive_images(dir_name, date_to): if date == today and not os.path.exists(path): continue filenames = traverse_all_files(path, filenames) - if len(filenames) > loads_num: + if len(filenames) > batch_size: break filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file)) else: - filenames = traverse_all_files(dir_name, filenames) + filenames = traverse_all_files(dir_name, filenames) + total_num = len(filenames) + batch_count = len(filenames) + 1 // batch_size + 1 tmparray = [(os.path.getmtime(file), file) for file in filenames ] date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400 filenames = [] @@ -132,8 +135,8 @@ def archive_images(dir_name, date_to): filenames.append((t, f ,date)) date_list = sorted(list(date_list.keys()), reverse=True) sort_array = sorted(filenames, key=lambda x:-x[0]) - if len(sort_array) > loads_num: - date = sort_array[loads_num][2] + if len(sort_array) > batch_size: + date = sort_array[batch_size][2] filenames = [x[1] for x in sort_array] else: date = date_to if len(sort_array) == 0 else sort_array[-1][2] @@ -141,9 +144,9 @@ def archive_images(dir_name, date_to): filenames = [x[1] for x in sort_array if x[2]>= date] num = len(filenames) last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000)) - date = date[:4] + "-" + date[4:6] + "-" + date[6:8] - date_to_bak = date_to_bak[:4] + "-" + date_to_bak[4:6] + "-" + date_to_bak[6:8] - load_info = f"Loaded {(num + 1) // opts.images_history_pages_num} pades, {num} images, during {date} - {date_to_bak}" + date = date[:4] + "/" + date[4:6] + "/" + date[6:8] + date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8] + load_info = f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages" _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames) return ( gradio.Dropdown.update(choices=date_list, value=date_to), @@ -154,12 +157,10 @@ def archive_images(dir_name, date_to): "", "", visible_num, - last_date_from + last_date_from, + #gradio.update(visible=batch_count > 1) ) - - - def delete_image(delete_num, name, filenames, image_index, visible_num): if name == "": return filenames, delete_num @@ -295,16 +296,16 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Column() as page_panel: with gr.Row(): - img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory") + img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir) with gr.Row(visible=False) as warning: warning_box = gr.Textbox("Message", interactive=False) with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel: with gr.Column(scale=2): - with gr.Row(): - backward = gr.Button('Backward') - date_to = gr.Dropdown(label="Date to") - forward = gr.Button('Forward') + with gr.Row() as batch_panel: + forward = gr.Button('Forward') + date_to = gr.Dropdown(label="Date to") + backward = gr.Button('Backward') newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") with gr.Row(): load_info = gr.Textbox(show_label=False, interactive=False) @@ -335,7 +336,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): # hiden items - with gr.Row(): #visible=False): + with gr.Row(visible=False): visible_img_num = gr.Number() date_to_recorder = gr.State([]) last_date_from = gr.Textbox() -- cgit v1.2.3 From cbf15edbf90a68a08eeab40af5df577ba4ac90b6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 18 Oct 2022 17:23:38 +0300 Subject: remove dependence on TQDM for sampler progress/interrupt functionality --- modules/processing.py | 6 --- modules/sd_samplers.py | 107 +++++++++++++++++++++++++++---------------------- 2 files changed, 58 insertions(+), 55 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index deb6125e..346eea88 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -402,12 +402,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) - if state.interrupted or state.skipped: - - # if we are interrupted, sample returns just noise - # use the image collected previously in sampler loop - samples_ddim = shared.state.current_latent - samples_ddim = samples_ddim.to(devices.dtype_vae) x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 20309e06..b58e810b 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -98,25 +98,8 @@ def store_latent(decoded): shared.state.current_image = sample_to_image(decoded) - -def extended_tdqm(sequence, *args, desc=None, **kwargs): - state.sampling_steps = len(sequence) - state.sampling_step = 0 - - seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs) - - for x in seq: - if state.interrupted or state.skipped: - break - - yield x - - state.sampling_step += 1 - shared.total_tqdm.update() - - -ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs) -ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs) +class InterruptedException(BaseException): + pass class VanillaStableDiffusionSampler: @@ -128,14 +111,32 @@ class VanillaStableDiffusionSampler: self.init_latent = None self.sampler_noises = None self.step = 0 + self.stop_at = None self.eta = None self.default_eta = 0.0 self.config = None + self.last_latent = None def number_of_needed_noises(self, p): return 0 + def launch_sampling(self, steps, func): + state.sampling_steps = steps + state.sampling_step = 0 + + try: + return func() + except InterruptedException: + return self.last_latent + def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): + if state.interrupted or state.skipped: + raise InterruptedException + + if self.stop_at is not None and self.step > self.stop_at: + raise InterruptedException + + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) @@ -159,11 +160,16 @@ class VanillaStableDiffusionSampler: res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) if self.mask is not None: - store_latent(self.init_latent * self.mask + self.nmask * res[1]) + self.last_latent = self.init_latent * self.mask + self.nmask * res[1] else: - store_latent(res[1]) + self.last_latent = res[1] + + store_latent(self.last_latent) self.step += 1 + state.sampling_step = self.step + shared.total_tqdm.update() + return res def initialize(self, p): @@ -192,7 +198,7 @@ class VanillaStableDiffusionSampler: self.init_latent = x self.step = 0 - samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning) + samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) return samples @@ -206,9 +212,9 @@ class VanillaStableDiffusionSampler: # existing code fails with certain step counts, like 9 try: - samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta) + samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) except Exception: - samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta) + samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) return samples_ddim @@ -223,6 +229,9 @@ class CFGDenoiser(torch.nn.Module): self.step = 0 def forward(self, x, sigma, uncond, cond, cond_scale): + if state.interrupted or state.skipped: + raise InterruptedException + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) @@ -268,25 +277,6 @@ class CFGDenoiser(torch.nn.Module): return denoised -def extended_trange(sampler, count, *args, **kwargs): - state.sampling_steps = count - state.sampling_step = 0 - - seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs) - - for x in seq: - if state.interrupted or state.skipped: - break - - if sampler.stop_at is not None and x > sampler.stop_at: - break - - yield x - - state.sampling_step += 1 - shared.total_tqdm.update() - - class TorchHijack: def __init__(self, kdiff_sampler): self.kdiff_sampler = kdiff_sampler @@ -314,9 +304,28 @@ class KDiffusionSampler: self.eta = None self.default_eta = 1.0 self.config = None + self.last_latent = None def callback_state(self, d): - store_latent(d["denoised"]) + step = d['i'] + latent = d["denoised"] + store_latent(latent) + self.last_latent = latent + + if self.stop_at is not None and step > self.stop_at: + raise InterruptedException + + state.sampling_step = step + shared.total_tqdm.update() + + def launch_sampling(self, steps, func): + state.sampling_steps = steps + state.sampling_step = 0 + + try: + return func() + except InterruptedException: + return self.last_latent def number_of_needed_noises(self, p): return p.steps @@ -339,9 +348,6 @@ class KDiffusionSampler: self.sampler_noise_index = 0 self.eta = p.eta or opts.eta_ancestral - if hasattr(k_diffusion.sampling, 'trange'): - k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs) - if self.sampler_noises is not None: k_diffusion.sampling.torch = TorchHijack(self) @@ -383,8 +389,9 @@ class KDiffusionSampler: self.model_wrap_cfg.init_latent = x - return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) + return samples def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): steps = steps or p.steps @@ -406,6 +413,8 @@ class KDiffusionSampler: extra_params_kwargs['n'] = steps else: extra_params_kwargs['sigmas'] = sigmas - samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) + + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) + return samples -- cgit v1.2.3 From 6021f7a75f7b5208a2be15cda5526028152f922d Mon Sep 17 00:00:00 2001 From: discus0434 Date: Wed, 19 Oct 2022 00:51:36 +0900 Subject: add options to custom hypernetwork layer structure --- modules/hypernetworks/hypernetwork.py | 88 ++++++++++++++++++++++++++--------- modules/shared.py | 4 +- 2 files changed, 70 insertions(+), 22 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 4905710e..cadb9911 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -1,52 +1,98 @@ +import csv import datetime import glob import html import os import sys import traceback -import tqdm -import csv +import modules.textual_inversion.dataset import torch - -from ldm.util import default -from modules import devices, shared, processing, sd_models -import torch -from torch import einsum +import tqdm from einops import rearrange, repeat -import modules.textual_inversion.dataset +from ldm.util import default +from modules import devices, processing, sd_models, shared from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler +from torch import einsum + + +def parse_layer_structure(dim, state_dict): + i = 0 + res = [1] + while (key := "linear.{}.weight".format(i)) in state_dict: + weight = state_dict[key] + res.append(len(weight) // dim) + i += 1 + return res class HypernetworkModule(torch.nn.Module): multiplier = 1.0 + layer_structure = None + add_layer_norm = False def __init__(self, dim, state_dict=None): super().__init__() + if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None: + layer_structure = (1, 2, 1) + else: + if self.layer_structure is not None: + assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" + assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" + layer_structure = self.layer_structure + else: + layer_structure = parse_layer_structure(dim, state_dict) + + linears = [] + for i in range(len(layer_structure) - 1): + linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) + if self.add_layer_norm: + linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) + self.linear = torch.nn.Sequential(*linears) if state_dict is not None: - self.load_state_dict(state_dict, strict=True) + try: + self.load_state_dict(state_dict) + except RuntimeError: + self.try_load_previous(state_dict) else: - - self.linear1.weight.data.normal_(mean=0.0, std=0.01) - self.linear1.bias.data.zero_() - self.linear2.weight.data.normal_(mean=0.0, std=0.01) - self.linear2.bias.data.zero_() + for layer in self.linear: + layer.weight.data.normal_(mean = 0.0, std = 0.01) + layer.bias.data.zero_() self.to(devices.device) + def try_load_previous(self, state_dict): + states = self.state_dict() + states['linear.0.bias'].copy_(state_dict['linear1.bias']) + states['linear.0.weight'].copy_(state_dict['linear1.weight']) + states['linear.1.bias'].copy_(state_dict['linear2.bias']) + states['linear.1.weight'].copy_(state_dict['linear2.weight']) + def forward(self, x): - return x + (self.linear2(self.linear1(x))) * self.multiplier + return x + self.linear(x) * self.multiplier + + def trainables(self): + res = [] + for layer in self.linear: + res += [layer.weight, layer.bias] + return res def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength +def apply_layer_structure(value=None): + HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure + + +def apply_layer_norm(value=None): + HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm + + class Hypernetwork: filename = None name = None @@ -68,7 +114,7 @@ class Hypernetwork: for k, layers in self.layers.items(): for layer in layers: layer.train() - res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] + res += layer.trainables() return res @@ -226,7 +272,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) - + assert ds.length > 1, "Dataset should contain more than 1 images" if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) @@ -261,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log with torch.autocast("cuda"): c = stack_conds([entry.cond for entry in entries]).to(devices.device) -# c = torch.vstack([entry.cond for entry in entries]).to(devices.device) + c = torch.vstack([entry.cond for entry in entries]).to(devices.device) x = torch.stack([entry.latent for entry in entries]).to(devices.device) loss = shared.sd_model(x, c)[0] del x @@ -283,7 +329,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{mean_loss:.7f}", - "learn_rate": scheduler.learn_rate + "learn_rate": f"{scheduler.learn_rate:.7f}" }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: diff --git a/modules/shared.py b/modules/shared.py index c0d87168..c87ce70e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization +from modules import sd_models, sd_samplers, localization from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -258,6 +258,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), + "sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}), + "sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), -- cgit v1.2.3 From a5611ea5026bd8e12d8e84023384c369d0511dda Mon Sep 17 00:00:00 2001 From: discus0434 Date: Wed, 19 Oct 2022 01:00:01 +0900 Subject: update --- modules/hypernetworks/hypernetwork.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index cadb9911..c5835bce 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -1,20 +1,22 @@ -import csv import datetime import glob import html import os import sys import traceback +import tqdm +import csv -import modules.textual_inversion.dataset import torch -import tqdm -from einops import rearrange, repeat + from ldm.util import default -from modules import devices, processing, sd_models, shared +from modules import devices, shared, processing, sd_models +import torch +from torch import einsum +from einops import rearrange, repeat +import modules.textual_inversion.dataset from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler -from torch import einsum def parse_layer_structure(dim, state_dict): -- cgit v1.2.3 From 7f2095c6c8db82a5c9cd7c7177f6ba856a2cc676 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Wed, 19 Oct 2022 01:01:22 +0900 Subject: update --- modules/shared.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index c87ce70e..6b6d5c41 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_models, sd_samplers, localization +from modules import sd_samplers, sd_models, localization from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -135,7 +135,7 @@ class State: self.job_no += 1 self.sampling_step = 0 self.current_image_sampling_step = 0 - + def get_job_timestamp(self): return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? -- cgit v1.2.3 From e40ba281f1b419cf99552962ea01d87d699840a5 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Wed, 19 Oct 2022 01:03:58 +0900 Subject: update --- modules/hypernetworks/hypernetwork.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c5835bce..082165f4 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -309,7 +309,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log with torch.autocast("cuda"): c = stack_conds([entry.cond for entry in entries]).to(devices.device) - c = torch.vstack([entry.cond for entry in entries]).to(devices.device) + # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) x = torch.stack([entry.latent for entry in entries]).to(devices.device) loss = shared.sd_model(x, c)[0] del x @@ -331,7 +331,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{mean_loss:.7f}", - "learn_rate": f"{scheduler.learn_rate:.7f}" + "learn_rate": scheduler.learn_rate }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: -- cgit v1.2.3 From e7f4808505f7a6339927c32b9a0c01bc9134bdeb Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Tue, 18 Oct 2022 19:04:56 +0000 Subject: provide sampler by name --- modules/api/api.py | 12 ++++++++++-- modules/api/processing.py | 16 ++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index ce98cb8c..ff9df0d1 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,14 +1,17 @@ from modules.api.processing import StableDiffusionProcessingAPI from modules.processing import StableDiffusionProcessingTxt2Img, process_images +from modules.sd_samplers import samplers_k_diffusion import modules.shared as shared import uvicorn -from fastapi import Body, APIRouter +from fastapi import Body, APIRouter, HTTPException from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, Json import json import io import base64 +sampler_to_index = lambda name: next(filter(lambda row: name in row[1][2], enumerate(samplers_k_diffusion)), None) + class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: Json @@ -23,9 +26,14 @@ class Api: self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): + sampler_index = sampler_to_index(txt2imgreq.sampler_index) + + if sampler_index is None: + raise HTTPException(status_code=404, detail="Sampler not found") + populate = txt2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": 0, + "sampler_index": sampler_index[0], "do_not_save_samples": True, "do_not_save_grid": True } diff --git a/modules/api/processing.py b/modules/api/processing.py index b6798241..2e6483ee 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -42,7 +42,8 @@ class PydanticModelGenerator: def __init__( self, model_name: str = None, - class_instance = None + class_instance = None, + additional_fields = None, ): def field_type_generator(k, v): # field_type = str if not overrides.get(k) else overrides[k]["type"] @@ -70,6 +71,13 @@ class PydanticModelGenerator: ) for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED ] + + for fields in additional_fields: + self._model_def.append(ModelDef( + field=underscore(fields["key"]), + field_alias=fields["key"], + field_type=fields["type"], + field_value=fields["default"])) def generate_model(self): """ @@ -84,4 +92,8 @@ class PydanticModelGenerator: DynamicModel.__config__.allow_mutation = True return DynamicModel -StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model() +StableDiffusionProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingTxt2Img", + StableDiffusionProcessingTxt2Img, + [{"key": "sampler_index", "type": str, "default": "k_euler_a"}] +).generate_model() -- cgit v1.2.3 From 538bc89c269743e56b07ef2b471d1ce0a39b6776 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Wed, 19 Oct 2022 11:27:51 +0800 Subject: Image browser improved --- modules/images_history.py | 135 +++++++++++++++++++++++++--------------------- modules/shared.py | 5 ++ modules/ui.py | 2 +- 3 files changed, 80 insertions(+), 62 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index a40cdc0e..78fd0543 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -4,7 +4,9 @@ import time import hashlib import gradio system_bak_path = "webui_log_and_bak" -browser_tabname = "custom" +custom_tab_name = "custom fold" +faverate_tab_name = "favorites" +tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name] def is_valid_date(date): try: time.strptime(date, "%Y%m%d") @@ -122,7 +124,6 @@ def archive_images(dir_name, date_to): else: filenames = traverse_all_files(dir_name, filenames) total_num = len(filenames) - batch_count = len(filenames) + 1 // batch_size + 1 tmparray = [(os.path.getmtime(file), file) for file in filenames ] date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400 filenames = [] @@ -146,10 +147,12 @@ def archive_images(dir_name, date_to): last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000)) date = date[:4] + "/" + date[4:6] + "/" + date[6:8] date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8] - load_info = f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages" + load_info = "
" + load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages" + load_info += "
" _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames) return ( - gradio.Dropdown.update(choices=date_list, value=date_to), + date_to, load_info, filenames, 1, @@ -158,7 +161,7 @@ def archive_images(dir_name, date_to): "", visible_num, last_date_from, - #gradio.update(visible=batch_count > 1) + gradio.update(visible=total_num > num) ) def delete_image(delete_num, name, filenames, image_index, visible_num): @@ -209,7 +212,7 @@ def get_recent_images(page_index, step, filenames): visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num return page_index, image_list, "", "", visible_num -def newest_click(date_to): +def loac_batch_click(date_to): if date_to is None: return time.strftime("%Y%m%d",time.localtime(time.time())), [] else: @@ -248,7 +251,7 @@ def page_index_change(page_index, filenames): def show_image_info(tabname_box, num, page_index, filenames): file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] - tm = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + tm = "
" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "
" return file, tm, num, file def enable_page_buttons(): @@ -268,9 +271,9 @@ def change_dir(img_dir, date_to): warning = "The format of the directory is incorrect" if warning is None: today = time.strftime("%Y%m%d",time.localtime(time.time())) - return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today + return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True) else: - return gradio.update(visible=True), gradio.update(visible=False), warning, date_to + return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False) def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): custom_dir = False @@ -280,7 +283,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): dir_name = opts.outdir_img2img_samples elif tabname == "extras": dir_name = opts.outdir_extras_samples - elif tabname == "saved": + elif tabname == faverate_tab_name: dir_name = opts.outdir_save else: custom_dir = True @@ -295,22 +298,26 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): os.makedirs(dir_name) with gr.Column() as page_panel: - with gr.Row(): - img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir) + with gr.Row(): + with gr.Column(scale=1, visible=not custom_dir) as load_batch_box: + load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True) + with gr.Column(scale=4): + with gr.Row(): + img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir) + with gr.Row(): + with gr.Column(visible=False, scale=1) as batch_panel: + with gr.Row(): + forward = gr.Button('Prev batch') + backward = gr.Button('Next batch') + with gr.Column(scale=3): + load_info = gr.HTML(visible=not custom_dir) with gr.Row(visible=False) as warning: warning_box = gr.Textbox("Message", interactive=False) with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel: - with gr.Column(scale=2): - with gr.Row() as batch_panel: - forward = gr.Button('Forward') - date_to = gr.Dropdown(label="Date to") - backward = gr.Button('Backward') - newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") - with gr.Row(): - load_info = gr.Textbox(show_label=False, interactive=False) - with gr.Row(visible=False) as turn_page_buttons: - renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") + with gr.Column(scale=2): + with gr.Row(visible=True) as turn_page_buttons: + #date_to = gr.Dropdown(label="Date to") first_page = gr.Button('First Page') prev_page = gr.Button('Prev Page') page_index = gr.Number(value=1, label="Page Index") @@ -322,50 +329,54 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - with gr.Column(): - with gr.Row(): - if tabname != "saved": - save_btn = gr.Button('Save') - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') + with gr.Column(): with gr.Row(): with gr.Column(): - img_file_info = gr.Textbox(label="Generate Info", interactive=False) + img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6) + gr.HTML("
") img_file_name = gr.Textbox(value="", label="File Name", interactive=False) - img_file_time= gr.Textbox(value="", label="Create Time", interactive=False) - + img_file_time= gr.HTML() + with gr.Row(): + if tabname != faverate_tab_name: + save_btn = gr.Button('Collect') + pnginfo_send_to_txt2img = gr.Button('Send to txt2img') + pnginfo_send_to_img2img = gr.Button('Send to img2img') + - # hiden items - with gr.Row(visible=False): - visible_img_num = gr.Number() - date_to_recorder = gr.State([]) - last_date_from = gr.Textbox() - tabname_box = gr.Textbox(tabname) - image_index = gr.Textbox(value=-1) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") - filenames = gr.State() - all_images_list = gr.State() - hidden = gr.Image(type="pil") - info1 = gr.Textbox() - info2 = gr.Textbox() - - img_path.submit(change_dir, inputs=[img_path, date_to], outputs=[warning, main_panel, warning_box, date_to]) - #change date - change_date_output = [date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from] + # hiden items + with gr.Row(visible=False): + renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") + batch_date_to = gr.Textbox(label="Date to") + visible_img_num = gr.Number() + date_to_recorder = gr.State([]) + last_date_from = gr.Textbox() + tabname_box = gr.Textbox(tabname) + image_index = gr.Textbox(value=-1) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") + filenames = gr.State() + all_images_list = gr.State() + hidden = gr.Image(type="pil") + info1 = gr.Textbox() + info2 = gr.Textbox() + + img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info]) + + #change batch + change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel] - date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output) - date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) - date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output) + batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) + batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - newest.click(newest_click, inputs=[date_to], outputs=[date_to, date_to_recorder]) - forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder]) - backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder]) + load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder]) + forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) + backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) #delete delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num]) delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None) - if tabname != "saved": + if tabname != faverate_tab_name: save_btn.click(save_image, inputs=[img_file_name], outputs=None) #turn page @@ -394,18 +405,20 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): -def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): +def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict): global opts; opts = sys_opts loads_files_num = int(opts.images_history_num_per_page) num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num) + if cmp_ops.browse_all_images: + tabs_list.append(custom_tab_name) with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: - for tab in [browser_tabname, "txt2img", "img2img", "extras", "saved"]: + for tab in tabs_list: with gr.Tab(tab): with gr.Blocks(analytics_enabled=False) : - show_images_history(gr, opts, tab, run_pnginfo, switch_dict) - #gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False) - gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_finish_render", visible=False) - + show_images_history(gr, opts, tab, run_pnginfo, switch_dict) + gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False) + gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False) + return images_history diff --git a/modules/shared.py b/modules/shared.py index 1811018d..4d735414 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,6 +74,10 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) +parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False) + + +cmd_opts = parser.parse_args() cmd_opts = parser.parse_args() @@ -311,6 +315,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" options_templates.update(options_section(('images-history', "Images Browser"), { #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), + "images_history_preload": OptionInfo(False, "Preload images at startup"), "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"), "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "), "images_history_grid_num": OptionInfo(6, "Number of grids in each row"), diff --git a/modules/ui.py b/modules/ui.py index 85abac4d..88f46659 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1150,7 +1150,7 @@ def create_ui(wrap_gradio_gpu_call): "i2i":img2img_paste_fields } - images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) + images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): -- cgit v1.2.3 From 0f0d6ab8e06898ce066251fc769fe14e77e98ced Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Wed, 19 Oct 2022 05:19:01 +0000 Subject: call sampler by name --- modules/api/api.py | 11 ++++++----- modules/api/processing.py | 6 +++--- 2 files changed, 9 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index ff9df0d1..5b0c934e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,6 +1,7 @@ from modules.api.processing import StableDiffusionProcessingAPI from modules.processing import StableDiffusionProcessingTxt2Img, process_images -from modules.sd_samplers import samplers_k_diffusion +from modules.sd_samplers import all_samplers +from modules.extras import run_pnginfo import modules.shared as shared import uvicorn from fastapi import Body, APIRouter, HTTPException @@ -10,7 +11,7 @@ import json import io import base64 -sampler_to_index = lambda name: next(filter(lambda row: name in row[1][2], enumerate(samplers_k_diffusion)), None) +sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") @@ -53,13 +54,13 @@ class Api: - def img2imgendoint(self): + def img2imgapi(self): raise NotImplementedError - def extrasendoint(self): + def extrasapi(self): raise NotImplementedError - def pnginfoendoint(self): + def pnginfoapi(self): raise NotImplementedError def launch(self, server_name, port): diff --git a/modules/api/processing.py b/modules/api/processing.py index 2e6483ee..4c541241 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -1,7 +1,7 @@ from inflection import underscore from typing import Any, Dict, Optional from pydantic import BaseModel, Field, create_model -from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +from modules.processing import StableDiffusionProcessingTxt2Img import inspect @@ -95,5 +95,5 @@ class PydanticModelGenerator: StableDiffusionProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "k_euler_a"}] -).generate_model() + [{"key": "sampler_index", "type": str, "default": "Euler"}] +).generate_model() \ No newline at end of file -- cgit v1.2.3 From 10aca1ca3e81e69e08f556a500c3dc603451429b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 19 Oct 2022 08:42:22 +0300 Subject: more careful loading of model weights (eliminates some issues with checkpoints that have weird cond_stage_model layer names) --- modules/sd_models.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 3aa21ec1..7ad6d474 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -122,11 +122,33 @@ def select_checkpoint(): return checkpoint_info +chckpoint_dict_replacements = { + 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', + 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', + 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', +} + + +def transform_checkpoint_dict_key(k): + for text, replacement in chckpoint_dict_replacements.items(): + if k.startswith(text): + k = replacement + k[len(text):] + + return k + + def get_state_dict_from_checkpoint(pl_sd): if "state_dict" in pl_sd: - return pl_sd["state_dict"] + pl_sd = pl_sd["state_dict"] + + sd = {} + for k, v in pl_sd.items(): + new_key = transform_checkpoint_dict_key(k) + + if new_key is not None: + sd[new_key] = v - return pl_sd + return sd def load_model_weights(model, checkpoint_info): @@ -141,7 +163,7 @@ def load_model_weights(model, checkpoint_info): print(f"Global Step: {pl_sd['global_step']}") sd = get_state_dict_from_checkpoint(pl_sd) - model.load_state_dict(sd, strict=False) + missing, extra = model.load_state_dict(sd, strict=False) if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) -- cgit v1.2.3 From da72becb13e4b750fbcb3d158c3f843311ef9938 Mon Sep 17 00:00:00 2001 From: Silent <16026653+s-ilent@users.noreply.github.com> Date: Wed, 19 Oct 2022 16:14:33 +1030 Subject: Use training width/height when training hypernetworks. --- modules/hypernetworks/hypernetwork.py | 4 ++-- modules/ui.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 4905710e..b8695fc1 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -196,7 +196,7 @@ def stack_conds(conds): return torch.stack(conds) -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert hypernetwork_name, 'hypernetwork not selected' path = shared.hypernetworks.get(hypernetwork_name, None) @@ -225,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) diff --git a/modules/ui.py b/modules/ui.py index fb6eb5a0..ca46343f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1341,6 +1341,8 @@ def create_ui(wrap_gradio_gpu_call): batch_size, dataset_directory, log_directory, + training_width, + training_height, steps, create_image_every, save_embedding_every, -- cgit v1.2.3 From 2fd7935ef4ed296db5dfd8c7fea99244816f8cf0 Mon Sep 17 00:00:00 2001 From: Cheka Date: Tue, 18 Oct 2022 20:28:28 -0300 Subject: Remove wrong self reference in CUDA support for invokeai --- modules/sd_hijack_optimizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index a3345bb9..98123fbf 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v): mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch # Divide factor of safety as there's copying and fragmentation - return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) def einsum_op(q, k, v): if q.device.type == 'cuda': -- cgit v1.2.3 From bcfbb33e50a48b237d8d961cc2be038db53774d5 Mon Sep 17 00:00:00 2001 From: Anastasius Date: Mon, 17 Oct 2022 13:35:20 -0700 Subject: Added time left estimation --- modules/ui.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index ca46343f..9a54aa16 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -261,6 +261,15 @@ def wrap_gradio_call(func, extra_outputs=None): return f +def calc_time_left(progress): + if progress == 0: + return "N/A" + else: + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + return time.strftime('%H:%M:%S', time.gmtime(eta-time_since_start)) + + def check_progress_call(id_part): if shared.state.job_count == 0: return "", gr_show(False), gr_show(False), gr_show(False) @@ -272,11 +281,13 @@ def check_progress_call(id_part): if shared.state.sampling_steps > 0: progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + time_left = calc_time_left( progress ) + progress = min(progress, 1) progressbar = "" if opts.show_progressbar: - progressbar = f"""
{str(int(progress*100))+"%" if progress > 0.01 else ""}
""" + progressbar = f"""
{str(int(progress*100))+"% ETA:"+time_left if progress > 0.01 else ""}
""" image = gr_show(False) preview_visibility = gr_show(False) @@ -308,6 +319,7 @@ def check_progress_call_initial(id_part): shared.state.current_latent = None shared.state.current_image = None shared.state.textinfo = None + shared.state.time_start = time.time() return check_progress_call(id_part) -- cgit v1.2.3 From 442dbedc159bb7e9cf94f0c3626f8a409e0a50eb Mon Sep 17 00:00:00 2001 From: Anastasius Date: Tue, 18 Oct 2022 10:38:07 -0700 Subject: Estimated time displayed if jobs take more 60 sec --- modules/ui.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 9a54aa16..fa54110b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -261,13 +261,17 @@ def wrap_gradio_call(func, extra_outputs=None): return f -def calc_time_left(progress): +def calc_time_left(progress, threshold, label, force_display): if progress == 0: - return "N/A" + return "" else: time_since_start = time.time() - shared.state.time_start eta = (time_since_start/progress) - return time.strftime('%H:%M:%S', time.gmtime(eta-time_since_start)) + eta_relative = eta-time_since_start + if eta_relative > threshold or force_display: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + else: + return "" def check_progress_call(id_part): @@ -281,13 +285,15 @@ def check_progress_call(id_part): if shared.state.sampling_steps > 0: progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - time_left = calc_time_left( progress ) + time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display ) + if time_left != "": + shared.state.time_left_force_display = True progress = min(progress, 1) progressbar = "" if opts.show_progressbar: - progressbar = f"""
{str(int(progress*100))+"% ETA:"+time_left if progress > 0.01 else ""}
""" + progressbar = f"""
{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}
""" image = gr_show(False) preview_visibility = gr_show(False) @@ -320,6 +326,7 @@ def check_progress_call_initial(id_part): shared.state.current_image = None shared.state.textinfo = None shared.state.time_start = time.time() + shared.state.time_left_force_display = False return check_progress_call(id_part) -- cgit v1.2.3 From 1d4aa376e6111e90888a30ae24d2bcd7f978ec51 Mon Sep 17 00:00:00 2001 From: Anastasius Date: Tue, 18 Oct 2022 12:42:39 -0700 Subject: Predictable long operation check for time estimation --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index fa54110b..38ba1138 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -268,7 +268,7 @@ def calc_time_left(progress, threshold, label, force_display): time_since_start = time.time() - shared.state.time_start eta = (time_since_start/progress) eta_relative = eta-time_since_start - if eta_relative > threshold or force_display: + if (eta_relative > threshold and progress > 0.02) or force_display: return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) else: return "" -- cgit v1.2.3 From bb0e7232b301d1706bbd0e09367dece3bb7ac07c Mon Sep 17 00:00:00 2001 From: Ikko Ashimine Date: Wed, 19 Oct 2022 02:18:56 +0900 Subject: Fix typo in prompt_parser.py assoicated -> associated --- modules/prompt_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 919d5d31..f70872c4 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -275,7 +275,7 @@ re_attention = re.compile(r""" def parse_prompt_attention(text): """ - Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight. + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. Accepted tokens are: (abc) - increases attention to abc by a multiplier of 1.1 (abc:3.12) - increases attention to abc by a multiplier of 3.12 -- cgit v1.2.3 From f894dd552f68bea27476f1f360ab8e79f3a65b4f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 19 Oct 2022 12:45:30 +0300 Subject: fix for broken checkpoint merger --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 7ad6d474..eae22e87 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -148,7 +148,10 @@ def get_state_dict_from_checkpoint(pl_sd): if new_key is not None: sd[new_key] = v - return sd + pl_sd.clear() + pl_sd.update(sd) + + return pl_sd def load_model_weights(model, checkpoint_info): -- cgit v1.2.3 From abeec4b63029c2c4151a78fc395d312113881845 Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 03:18:26 -0700 Subject: Add auto focal point cropping to Preprocess images This algorithm plots a bunch of points of interest on the source image and averages their locations to find a center. Most points come from OpenCV. One point comes from an entropy model. OpenCV points account for 50% of the weight and the entropy based point is the other 50%. The center of all weighted points is calculated and a bounding box is drawn as close to centered over that point as possible. --- modules/textual_inversion/preprocess.py | 151 ++++++++++++++++++++++++++++++-- 1 file changed, 146 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 886cf0c3..168bfb09 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,5 +1,7 @@ import os -from PIL import Image, ImageOps +import cv2 +import numpy as np +from PIL import Image, ImageOps, ImageDraw import platform import sys import tqdm @@ -11,7 +13,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False): try: if process_caption: shared.interrogator.load() @@ -21,7 +23,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, process_entropy_focus) finally: @@ -33,7 +35,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ -def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False): width = process_width height = process_height src = os.path.abspath(process_src) @@ -93,6 +95,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro is_tall = ratio > 1.35 is_wide = ratio < 1 / 1.35 + processing_option_ran = False + if process_split and is_tall: img = img.resize((width, height * img.height // img.width)) @@ -101,6 +105,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro bot = img.crop((0, img.height - height, width, img.height)) save_pic(bot, index) + + processing_option_ran = True elif process_split and is_wide: img = img.resize((width * img.width // img.height, height)) @@ -109,8 +115,143 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro right = img.crop((img.width - width, 0, img.width, height)) save_pic(right, index) - else: + + processing_option_ran = True + + if process_entropy_focus and (is_tall or is_wide): + if is_tall: + img = img.resize((width, height * img.height // img.width)) + else: + img = img.resize((width * img.width // img.height, height)) + + x_focal_center, y_focal_center = image_central_focal_point(img, width, height) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(height / 2) + x_half = int(width / 2) + + x1 = x_focal_center - x_half + if x1 < 0: + x1 = 0 + elif x1 + width > img.width: + x1 = img.width - width + + y1 = y_focal_center - y_half + if y1 < 0: + y1 = 0 + elif y1 + height > img.height: + y1 = img.height - height + + x2 = x1 + width + y2 = y1 + height + + crop = [x1, y1, x2, y2] + + focal = img.crop(tuple(crop)) + save_pic(focal, index) + + processing_option_ran = True + + if not processing_option_ran: img = images.resize_image(1, img, width, height) save_pic(img, index) shared.state.nextjob() + + +def image_central_focal_point(im, target_width, target_height): + focal_points = [] + + focal_points.extend( + image_focal_points(im) + ) + + fp_entropy = image_entropy_point(im, target_width, target_height) + fp_entropy['weight'] = len(focal_points) + 1 # about half of the weight to entropy + + focal_points.append(fp_entropy) + + weight = 0.0 + x = 0.0 + y = 0.0 + for focal_point in focal_points: + weight += focal_point['weight'] + x += focal_point['x'] * focal_point['weight'] + y += focal_point['y'] * focal_point['weight'] + avg_x = round(x // weight) + avg_y = round(y // weight) + + return avg_x, avg_y + + +def image_focal_points(im): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=50, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.05, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append({ + 'x': x, + 'y': y, + 'weight': 1.0 + }) + + return focal_points + + +def image_entropy_point(im, crop_width, crop_height): + img = im.copy() + # just make it easier to slide the test crop with images oriented the same way + if (img.size[0] < img.size[1]): + portrait = True + img = img.rotate(90, expand=1) + + e_max = 0 + crop_current = [0, 0, crop_width, crop_height] + crop_best = crop_current + while crop_current[2] < img.size[0]: + crop = img.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e_max < e): + e_max = e + crop_best = list(crop_current) + + crop_current[0] += 4 + crop_current[2] += 4 + + x_mid = int((crop_best[2] - crop_best[0])/2) + y_mid = int((crop_best[3] - crop_best[1])/2) + + return { + 'x': x_mid, + 'y': y_mid, + 'weight': 1.0 + } + + +def image_entropy(im): + # greyscale image entropy + band = np.asarray(im.convert("L")) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + -- cgit v1.2.3 From 087609ee181a91a523647435ffffa6288a317e2f Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 03:19:35 -0700 Subject: UI changes for focal point image cropping --- modules/ui.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 1ff7eb4f..b6be713b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1234,6 +1234,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') process_split = gr.Checkbox(label='Split oversized images into two') + process_entropy_focus = gr.Checkbox(label='Create auto focal point crop') process_caption = gr.Checkbox(label='Use BLIP for caption') process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) @@ -1318,7 +1319,8 @@ def create_ui(wrap_gradio_gpu_call): process_flip, process_split, process_caption, - process_caption_deepbooru + process_caption_deepbooru, + process_entropy_focus ], outputs=[ ti_output, -- cgit v1.2.3 From 42fbda83bb9830af18187fddb50c1bedd01da502 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Wed, 19 Oct 2022 14:30:33 +0000 Subject: layer options moves into create hnet ui --- modules/hypernetworks/hypernetwork.py | 64 +++++++++++++++++------------------ modules/hypernetworks/ui.py | 9 +++-- modules/shared.py | 2 -- modules/ui.py | 8 +++-- 4 files changed, 45 insertions(+), 38 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 583ada31..7d519cd9 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -19,37 +19,21 @@ from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler -def parse_layer_structure(dim, state_dict): - i = 0 - res = [1] - while (key := "linear.{}.weight".format(i)) in state_dict: - weight = state_dict[key] - res.append(len(weight) // dim) - i += 1 - return res - - class HypernetworkModule(torch.nn.Module): multiplier = 1.0 - layer_structure = None - add_layer_norm = False - def __init__(self, dim, state_dict=None): + def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False): super().__init__() - if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None: - layer_structure = (1, 2, 1) + if layer_structure is not None: + assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" + assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" else: - if self.layer_structure is not None: - assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" - assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - layer_structure = self.layer_structure - else: - layer_structure = parse_layer_structure(dim, state_dict) + layer_structure = parse_layer_structure(dim, state_dict) linears = [] for i in range(len(layer_structure) - 1): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) - if self.add_layer_norm: + if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) self.linear = torch.nn.Sequential(*linears) @@ -77,38 +61,47 @@ class HypernetworkModule(torch.nn.Module): return x + self.linear(x) * self.multiplier def trainables(self): - res = [] + layer_structure = [] for layer in self.linear: - res += [layer.weight, layer.bias] - return res + layer_structure += [layer.weight, layer.bias] + return layer_structure def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength -def apply_layer_structure(value=None): - HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure +def parse_layer_structure(dim, state_dict): + i = 0 + layer_structure = [1] + while (key := "linear.{}.weight".format(i)) in state_dict: + weight = state_dict[key] + layer_structure.append(len(weight) // dim) + i += 1 -def apply_layer_norm(value=None): - HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm + return layer_structure class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False): self.filename = None self.name = name self.layers = {} self.step = 0 self.sd_checkpoint = None self.sd_checkpoint_name = None + self.layer_structure = layer_structure + self.add_layer_norm = add_layer_norm for size in enable_sizes or []: - self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) + self.layers[size] = ( + HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), + ) def weights(self): res = [] @@ -128,6 +121,8 @@ class Hypernetwork: state_dict['step'] = self.step state_dict['name'] = self.name + state_dict['layer_structure'] = self.layer_structure + state_dict['is_layer_norm'] = self.add_layer_norm state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name @@ -142,10 +137,15 @@ class Hypernetwork: for size, sd in state_dict.items(): if type(size) == int: - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + self.layers[size] = ( + HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]), + HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]), + ) self.name = state_dict.get('name', self.name) self.step = state_dict.get('step', 0) + self.layer_structure = state_dict.get('layer_structure', None) + self.add_layer_norm = state_dict.get('is_layer_norm', False) self.sd_checkpoint = state_dict.get('sd_checkpoint', None) self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index dfa599af..7e8ea95e 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,11 +9,16 @@ from modules import sd_hijack, shared, devices from modules.hypernetworks import hypernetwork -def create_hypernetwork(name, enable_sizes): +def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" - hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes]) + hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( + name=name, + enable_sizes=[int(x) for x in enable_sizes], + layer_structure=layer_structure, + add_layer_norm=add_layer_norm, + ) hypernet.save(fn) shared.reload_hypernetworks() diff --git a/modules/shared.py b/modules/shared.py index 0540cae9..faede821 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -260,8 +260,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), - "sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}), - "sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), diff --git a/modules/ui.py b/modules/ui.py index ca46343f..d9ee462f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -458,14 +458,14 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=80): with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" ) with gr.Row(): with gr.Column(scale=80): with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" ) @@ -1198,6 +1198,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tab(label="Create hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) + new_hypernetwork_layer_structure = gr.Dropdown(label="Hypernetwork layer structure", choices=[(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]) + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") with gr.Row(): with gr.Column(scale=3): @@ -1280,6 +1282,8 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ new_hypernetwork_name, new_hypernetwork_sizes, + new_hypernetwork_layer_structure, + new_hypernetwork_add_layer_norm, ], outputs=[ train_hypernetwork_name, -- cgit v1.2.3 From 3770b8d2fa62066d472a04739c7b84bce8538832 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Wed, 19 Oct 2022 15:28:42 +0000 Subject: enable to write layer structure of hn himself --- modules/hypernetworks/ui.py | 4 ++++ modules/ui.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 7e8ea95e..08f75f15 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -1,5 +1,6 @@ import html import os +import re import gradio as gr @@ -13,6 +14,9 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" + if type(layer_structure) == str: + layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure))) + hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( name=name, enable_sizes=[int(x) for x in enable_sizes], diff --git a/modules/ui.py b/modules/ui.py index d9ee462f..18a2add0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1198,7 +1198,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tab(label="Create hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) - new_hypernetwork_layer_structure = gr.Dropdown(label="Hypernetwork layer structure", choices=[(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]) + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") with gr.Row(): -- cgit v1.2.3 From 019a3a88f07766f2d32c32fbe8e41625f28ecb5e Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 19 Oct 2022 17:15:47 +0100 Subject: Update ui.py --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index d2e24880..1573ef82 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1247,7 +1247,7 @@ def create_ui(wrap_gradio_gpu_call): run_preprocess = gr.Button(value="Preprocess", variant='primary') with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") + gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images
Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork wiki

") with gr.Row(): train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") -- cgit v1.2.3 From c6e9fed5003631c87d548e74d6e359678959a453 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 19 Oct 2022 19:21:16 +0300 Subject: fix for #3086 failing to load any previous hypernet --- modules/hypernetworks/hypernetwork.py | 60 ++++++++++++++++------------------- 1 file changed, 28 insertions(+), 32 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d519cd9..74300122 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -24,11 +24,10 @@ class HypernetworkModule(torch.nn.Module): def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False): super().__init__() - if layer_structure is not None: - assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" - assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - else: - layer_structure = parse_layer_structure(dim, state_dict) + + assert layer_structure is not None, "layer_structure mut not be None" + assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" + assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" linears = [] for i in range(len(layer_structure) - 1): @@ -39,23 +38,30 @@ class HypernetworkModule(torch.nn.Module): self.linear = torch.nn.Sequential(*linears) if state_dict is not None: - try: - self.load_state_dict(state_dict) - except RuntimeError: - self.try_load_previous(state_dict) + self.fix_old_state_dict(state_dict) + self.load_state_dict(state_dict) else: for layer in self.linear: - layer.weight.data.normal_(mean = 0.0, std = 0.01) + layer.weight.data.normal_(mean=0.0, std=0.01) layer.bias.data.zero_() self.to(devices.device) - def try_load_previous(self, state_dict): - states = self.state_dict() - states['linear.0.bias'].copy_(state_dict['linear1.bias']) - states['linear.0.weight'].copy_(state_dict['linear1.weight']) - states['linear.1.bias'].copy_(state_dict['linear2.bias']) - states['linear.1.weight'].copy_(state_dict['linear2.weight']) + def fix_old_state_dict(self, state_dict): + changes = { + 'linear1.bias': 'linear.0.bias', + 'linear1.weight': 'linear.0.weight', + 'linear2.bias': 'linear.1.bias', + 'linear2.weight': 'linear.1.weight', + } + + for fr, to in changes.items(): + x = state_dict.get(fr, None) + if x is None: + continue + + del state_dict[fr] + state_dict[to] = x def forward(self, x): return x + self.linear(x) * self.multiplier @@ -71,18 +77,6 @@ def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength -def parse_layer_structure(dim, state_dict): - i = 0 - layer_structure = [1] - - while (key := "linear.{}.weight".format(i)) in state_dict: - weight = state_dict[key] - layer_structure.append(len(weight) // dim) - i += 1 - - return layer_structure - - class Hypernetwork: filename = None name = None @@ -135,17 +129,18 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu') + self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) + self.add_layer_norm = state_dict.get('is_layer_norm', False) + for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]), - HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]), + HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm), ) self.name = state_dict.get('name', self.name) self.step = state_dict.get('step', 0) - self.layer_structure = state_dict.get('layer_structure', None) - self.add_layer_norm = state_dict.get('is_layer_norm', False) self.sd_checkpoint = state_dict.get('sd_checkpoint', None) self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) @@ -244,6 +239,7 @@ def stack_conds(conds): return torch.stack(conds) + def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert hypernetwork_name, 'hypernetwork not selected' -- cgit v1.2.3 From 2ce52d32e41fb523d1494f45073fd18496e52d35 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Wed, 19 Oct 2022 16:31:12 +0000 Subject: fix for #3086 failing to load any previous hypernet --- modules/hypernetworks/hypernetwork.py | 60 ++++++++++++++++------------------- 1 file changed, 28 insertions(+), 32 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d519cd9..74300122 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -24,11 +24,10 @@ class HypernetworkModule(torch.nn.Module): def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False): super().__init__() - if layer_structure is not None: - assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" - assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - else: - layer_structure = parse_layer_structure(dim, state_dict) + + assert layer_structure is not None, "layer_structure mut not be None" + assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" + assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" linears = [] for i in range(len(layer_structure) - 1): @@ -39,23 +38,30 @@ class HypernetworkModule(torch.nn.Module): self.linear = torch.nn.Sequential(*linears) if state_dict is not None: - try: - self.load_state_dict(state_dict) - except RuntimeError: - self.try_load_previous(state_dict) + self.fix_old_state_dict(state_dict) + self.load_state_dict(state_dict) else: for layer in self.linear: - layer.weight.data.normal_(mean = 0.0, std = 0.01) + layer.weight.data.normal_(mean=0.0, std=0.01) layer.bias.data.zero_() self.to(devices.device) - def try_load_previous(self, state_dict): - states = self.state_dict() - states['linear.0.bias'].copy_(state_dict['linear1.bias']) - states['linear.0.weight'].copy_(state_dict['linear1.weight']) - states['linear.1.bias'].copy_(state_dict['linear2.bias']) - states['linear.1.weight'].copy_(state_dict['linear2.weight']) + def fix_old_state_dict(self, state_dict): + changes = { + 'linear1.bias': 'linear.0.bias', + 'linear1.weight': 'linear.0.weight', + 'linear2.bias': 'linear.1.bias', + 'linear2.weight': 'linear.1.weight', + } + + for fr, to in changes.items(): + x = state_dict.get(fr, None) + if x is None: + continue + + del state_dict[fr] + state_dict[to] = x def forward(self, x): return x + self.linear(x) * self.multiplier @@ -71,18 +77,6 @@ def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength -def parse_layer_structure(dim, state_dict): - i = 0 - layer_structure = [1] - - while (key := "linear.{}.weight".format(i)) in state_dict: - weight = state_dict[key] - layer_structure.append(len(weight) // dim) - i += 1 - - return layer_structure - - class Hypernetwork: filename = None name = None @@ -135,17 +129,18 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu') + self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) + self.add_layer_norm = state_dict.get('is_layer_norm', False) + for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]), - HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]), + HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm), ) self.name = state_dict.get('name', self.name) self.step = state_dict.get('step', 0) - self.layer_structure = state_dict.get('layer_structure', None) - self.add_layer_norm = state_dict.get('is_layer_norm', False) self.sd_checkpoint = state_dict.get('sd_checkpoint', None) self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) @@ -244,6 +239,7 @@ def stack_conds(conds): return torch.stack(conds) + def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert hypernetwork_name, 'hypernetwork not selected' -- cgit v1.2.3 From 57eb1a64c85d995cacb4fa3832e87405bf6820b9 Mon Sep 17 00:00:00 2001 From: Alexandre Simard Date: Wed, 19 Oct 2022 12:28:27 -0400 Subject: Update ui.py --- modules/ui.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index d2e24880..c9a923ab 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -268,8 +268,13 @@ def calc_time_left(progress, threshold, label, force_display): time_since_start = time.time() - shared.state.time_start eta = (time_since_start/progress) eta_relative = eta-time_since_start - if (eta_relative > threshold and progress > 0.02) or force_display: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + if (eta_relative > threshold and progress > 0.02) or force_display: + if eta_relative > 3600: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + elif eta_relative > 60: + return label + time.strftime('%M:%S', time.gmtime(eta_relative)) + else: + return label + time.strftime('%Ss', time.gmtime(eta_relative)) else: return "" @@ -285,7 +290,7 @@ def check_progress_call(id_part): if shared.state.sampling_steps > 0: progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display ) + time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display ) if time_left != "": shared.state.time_left_force_display = True @@ -293,7 +298,7 @@ def check_progress_call(id_part): progressbar = "" if opts.show_progressbar: - progressbar = f"""
{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}
""" + progressbar = f"""
{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}
""" image = gr_show(False) preview_visibility = gr_show(False) -- cgit v1.2.3 From 1e4809b251d478a102fd980dcfc26e21d6d3730b Mon Sep 17 00:00:00 2001 From: Alexandre Simard Date: Wed, 19 Oct 2022 12:53:23 -0400 Subject: Added a bit of padding to the left --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index c9a923ab..a2dbd41e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -298,7 +298,7 @@ def check_progress_call(id_part): progressbar = "" if opts.show_progressbar: - progressbar = f"""
{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}
""" + progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}
""" image = gr_show(False) preview_visibility = gr_show(False) -- cgit v1.2.3 From 5e012e4dfa5dcfeade0394678cf14b70682dba6c Mon Sep 17 00:00:00 2001 From: timntorres Date: Wed, 19 Oct 2022 06:17:47 -0700 Subject: Infotext saves more specific hypernet name. --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index ea926fc3..bcb0c32c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')), + "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), -- cgit v1.2.3 From 46122c4ff6aadc0f96e657f88dbac7bbd9f9bf99 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Wed, 19 Oct 2022 19:18:52 +0300 Subject: Send empty prompts as valid generation parameter --- modules/generation_parameters_copypaste.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index c27826b6..98d24406 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -45,10 +45,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model else: prompt += ("" if prompt == "" else "\n") + line - if len(prompt) > 0: res["Prompt"] = prompt - - if len(negative_prompt) > 0: res["Negative prompt"] = negative_prompt for k, v in re_param.findall(lastline): -- cgit v1.2.3 From 14c1c2b9351f16d43ba4e6b6c9062edad44a6bec Mon Sep 17 00:00:00 2001 From: Alexandre Simard Date: Wed, 19 Oct 2022 13:53:52 -0400 Subject: Show PB texts at same time and earlier For big tasks (1000+ steps), waiting 1 minute to see ETA is long and this changes it so the number of steps done plays a role in showing the text as well. --- modules/ui.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index a2dbd41e..0abd177a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -261,14 +261,14 @@ def wrap_gradio_call(func, extra_outputs=None): return f -def calc_time_left(progress, threshold, label, force_display): +def calc_time_left(progress, threshold, label, force_display, showTime): if progress == 0: return "" else: time_since_start = time.time() - shared.state.time_start eta = (time_since_start/progress) eta_relative = eta-time_since_start - if (eta_relative > threshold and progress > 0.02) or force_display: + if (eta_relative > threshold and showTime) or force_display: if eta_relative > 3600: return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) elif eta_relative > 60: @@ -290,7 +290,10 @@ def check_progress_call(id_part): if shared.state.sampling_steps > 0: progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display ) + # Show progress percentage and time left at the same moment, and base it also on steps done + showPBText = progress >= 0.01 or shared.state.sampling_step >= 10 + + time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display, showPBText ) if time_left != "": shared.state.time_left_force_display = True @@ -298,7 +301,7 @@ def check_progress_call(id_part): progressbar = "" if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}
""" + progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if showPBText else ""}
""" image = gr_show(False) preview_visibility = gr_show(False) -- cgit v1.2.3 From b748b583c0b9f771c1be509175a6913e3f2ad97c Mon Sep 17 00:00:00 2001 From: Mackerel Date: Wed, 19 Oct 2022 14:22:03 -0400 Subject: generation_parameters_copypaste.py: fix indent --- modules/generation_parameters_copypaste.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 98d24406..0f041449 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -45,8 +45,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model else: prompt += ("" if prompt == "" else "\n") + line - res["Prompt"] = prompt - res["Negative prompt"] = negative_prompt + res["Prompt"] = prompt + res["Negative prompt"] = negative_prompt for k, v in re_param.findall(lastline): m = re_imagesize.match(v) -- cgit v1.2.3 From eb7ba4b713ac2fb960ecf6365b1de0c89451e583 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 19 Oct 2022 19:50:46 +0100 Subject: update training header text --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 1573ef82..93c0767c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1247,7 +1247,7 @@ def create_ui(wrap_gradio_gpu_call): run_preprocess = gr.Button(value="Preprocess", variant='primary') with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images
Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork wiki

") + gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images
Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork [wiki]

") with gr.Row(): train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") -- cgit v1.2.3 From 4d663055ded968831ec97f047dfa8e94036cf1c1 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 19 Oct 2022 20:33:18 +0100 Subject: update ui with extra training options --- modules/ui.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 93c0767c..cdb9d335 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1206,6 +1206,7 @@ def create_ui(wrap_gradio_gpu_call): new_embedding_name = gr.Textbox(label="Name") initialization_text = gr.Textbox(label="Initialization text", value="*") nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") with gr.Row(): with gr.Column(scale=3): @@ -1219,6 +1220,7 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") with gr.Row(): with gr.Column(scale=3): @@ -1247,14 +1249,17 @@ def create_ui(wrap_gradio_gpu_call): run_preprocess = gr.Button(value="Preprocess", variant='primary') with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images
Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork [wiki]

") + gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") with gr.Row(): train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") with gr.Row(): train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") + with gr.Row(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") + batch_size = gr.Number(label='Batch size', value=1, precision=0) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") @@ -1288,6 +1293,7 @@ def create_ui(wrap_gradio_gpu_call): new_embedding_name, initialization_text, nvpt, + overwrite_old_embedding, ], outputs=[ train_embedding_name, @@ -1303,6 +1309,7 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_sizes, new_hypernetwork_layer_structure, new_hypernetwork_add_layer_norm, + overwrite_old_hypernetwork, ], outputs=[ train_hypernetwork_name, -- cgit v1.2.3 From 41e3877be2c667316515c86037413763eb0ba4da Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 13:44:59 -0700 Subject: fix entropy point calculation --- modules/textual_inversion/preprocess.py | 34 ++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 168bfb09..7c1a594e 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -196,9 +196,9 @@ def image_focal_points(im): points = cv2.goodFeaturesToTrack( np_im, - maxCorners=50, + maxCorners=100, qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.05, + minDistance=min(grayscale.width, grayscale.height)*0.07, useHarrisDetector=False, ) @@ -218,28 +218,32 @@ def image_focal_points(im): def image_entropy_point(im, crop_width, crop_height): - img = im.copy() - # just make it easier to slide the test crop with images oriented the same way - if (img.size[0] < img.size[1]): - portrait = True - img = img.rotate(90, expand=1) + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] e_max = 0 crop_current = [0, 0, crop_width, crop_height] crop_best = crop_current - while crop_current[2] < img.size[0]: - crop = img.crop(tuple(crop_current)) + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) e = image_entropy(crop) - if (e_max < e): + if (e > e_max): e_max = e crop_best = list(crop_current) - crop_current[0] += 4 - crop_current[2] += 4 + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + crop_width/2) + y_mid = int(crop_best[1] + crop_height/2) - x_mid = int((crop_best[2] - crop_best[0])/2) - y_mid = int((crop_best[3] - crop_best[1])/2) return { 'x': x_mid, @@ -250,7 +254,7 @@ def image_entropy_point(im, crop_width, crop_height): def image_entropy(im): # greyscale image entropy - band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1")) hist, _ = np.histogram(band, bins=range(0, 256)) hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() -- cgit v1.2.3 From 8e7097d06a6a261580d34375c9d2a9e4ffc63ffa Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Wed, 19 Oct 2022 13:47:45 -0700 Subject: Added support for RunwayML inpainting model --- modules/processing.py | 34 ++++++- modules/sd_hijack_inpainting.py | 208 ++++++++++++++++++++++++++++++++++++++++ modules/sd_models.py | 16 +++- modules/sd_samplers.py | 50 +++++++--- 4 files changed, 293 insertions(+), 15 deletions(-) create mode 100644 modules/sd_hijack_inpainting.py (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index bcb0c32c..a6c308f9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -546,7 +546,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) + + # The "masked-image" in this case will just be all zeros since the entire image is masked. + image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device) + image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=image_conditioning) return samples x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) @@ -714,10 +723,31 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask + if self.image_mask is not None: + conditioning_mask = np.array(self.image_mask.convert("L")) + conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 + conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: + conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) + + # Create another latent image, this time with a masked version of the original input. + conditioning_mask = conditioning_mask.to(image.device) + conditioning_image = image * (1.0 - conditioning_mask) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + + # Create the concatenated conditioning tensor to be fed to `c_concat` + conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) + conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) + self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) + self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning) + samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py new file mode 100644 index 00000000..7e5670d6 --- /dev/null +++ b/modules/sd_hijack_inpainting.py @@ -0,0 +1,208 @@ +import torch +import numpy as np + +from tqdm import tqdm +from einops import rearrange, repeat +from omegaconf import ListConfig + +from types import MethodType + +import ldm.models.diffusion.ddpm +import ldm.models.diffusion.ddim + +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.models.diffusion.ddim import DDIMSampler, noise_like + +# ================================================================================================= +# Monkey patch DDIMSampler methods from RunwayML repo directly. +# Adapted from: +# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py +# ================================================================================================= +@torch.no_grad() +def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = elf.inpainting_fill == 2: + self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask + elif self.inpainting_fill == 3: + self.init_latent = self.init_latent * self.mask + + if self.image_mask is not None: + conditioning_mask = np.array(self.image_mask.convert("L")) + conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 + conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: + conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) + + # Create another latent image, this time with a masked version of the original input. + conditioning_mask = conditioning_mask.to(image.device) + conditioning_image = image * (1.0 - conditioning_mask) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + + # Create the concatenated conditioning tensor to be fed to `c_concat` + conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) + conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) + self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) + self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) + + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): + x = create_random_tensors([opctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + +@torch.no_grad() +def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + +# ================================================================================================= +# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config. +# Adapted from: +# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py +# ================================================================================================= + +@torch.no_grad() +def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + # todo: get null label from cond_stage_model + raise NotImplementedError() + c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) + return c + +class LatentInpaintDiffusion(LatentDiffusion): + def __init__( + self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + self.concat_keys = concat_keys + +def should_hijack_inpainting(checkpoint_info): + return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml") + +def do_inpainting_hijack(): + ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning + ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion + ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim + ldm.models.diffusion.ddim.DDIMSampler.sample = sample \ No newline at end of file diff --git a/modules/sd_models.py b/modules/sd_models.py index eae22e87..47836d25 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config from modules import shared, modelloader, devices from modules.paths import models_path +from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -211,6 +212,19 @@ def load_model(): print(f"Loading config from: {checkpoint_info.config}") sd_config = OmegaConf.load(checkpoint_info.config) + + if should_hijack_inpainting(checkpoint_info): + do_inpainting_hijack() + + # Hardcoded config for now... + sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" + sd_config.model.params.use_ema = False + sd_config.model.params.conditioning_key = "hybrid" + sd_config.model.params.unet_config.params.in_channels = 9 + + # Create a "fake" config with a different name so that we know to unload it when switching models. + checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) @@ -234,7 +248,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if sd_model.sd_checkpoint_info.config != checkpoint_info.config: + if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() shared.sd_model = load_model() return shared.sd_model diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index b58e810b..9d3cf289 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -136,9 +136,15 @@ class VanillaStableDiffusionSampler: if self.stop_at is not None and self.step > self.stop_at: raise InterruptedException + # Have to unwrap the inpainting conditioning here to perform pre-preocessing + image_conditioning = None + if isinstance(cond, dict): + image_conditioning = cond["c_concat"][0] + cond = cond["c_crossattn"][0] + unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) + unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' cond = tensor @@ -157,6 +163,10 @@ class VanillaStableDiffusionSampler: img_orig = self.sampler.model.q_sample(self.init_latent, ts) x_dec = img_orig * self.mask + self.nmask * x_dec + if image_conditioning is not None: + cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) if self.mask is not None: @@ -182,7 +192,7 @@ class VanillaStableDiffusionSampler: self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = setup_img2img_steps(p, steps) self.initialize(p) @@ -202,7 +212,7 @@ class VanillaStableDiffusionSampler: return samples - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): self.initialize(p) self.init_latent = None @@ -210,6 +220,11 @@ class VanillaStableDiffusionSampler: steps = steps or p.steps + # Wrap the conditioning models with additional image conditioning for inpainting model + if image_conditioning is not None: + conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + # existing code fails with certain step counts, like 9 try: samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) @@ -228,7 +243,7 @@ class CFGDenoiser(torch.nn.Module): self.init_latent = None self.step = 0 - def forward(self, x, sigma, uncond, cond, cond_scale): + def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): if state.interrupted or state.skipped: raise InterruptedException @@ -239,28 +254,29 @@ class CFGDenoiser(torch.nn.Module): repeats = [len(conds_list[i]) for i in range(batch_size)] x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) if tensor.shape[1] == uncond.shape[1]: cond_in = torch.cat([tensor, uncond]) if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=cond_in) + x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) else: x_out = torch.zeros_like(x_in) for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b]) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size for batch_offset in range(0, tensor.shape[0], batch_size): a = batch_offset b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b]) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond) + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) denoised_uncond = x_out[-uncond.shape[0]:] denoised = torch.clone(denoised_uncond) @@ -361,7 +377,7 @@ class KDiffusionSampler: return extra_params_kwargs - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = setup_img2img_steps(p, steps) if p.sampler_noise_scheduler_override: @@ -389,11 +405,16 @@ class KDiffusionSampler: self.model_wrap_cfg.init_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): steps = steps or p.steps if p.sampler_noise_scheduler_override: @@ -414,7 +435,12 @@ class KDiffusionSampler: else: extra_params_kwargs['sigmas'] = sigmas - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples -- cgit v1.2.3 From 0719c10bf1b817364a498ee11b90d30d3d527344 Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Wed, 19 Oct 2022 13:56:26 -0700 Subject: Fixed copying mistake --- modules/sd_hijack_inpainting.py | 79 +++++++++++++---------------------------- 1 file changed, 25 insertions(+), 54 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 7e5670d6..d4d28d2e 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -19,63 +19,35 @@ from ldm.models.diffusion.ddim import DDIMSampler, noise_like # https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py # ================================================================================================= @torch.no_grad() -def sample( - self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): +def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): if conditioning is not None: if isinstance(conditioning, dict): ctmp = conditioning[list(conditioning.keys())[0]] while isinstance(ctmp, list): - ctmp = elf.inpainting_fill == 2: - self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask - elif self.inpainting_fill == 3: - self.init_latent = self.init_latent * self.mask - - if self.image_mask is not None: - conditioning_mask = np.array(self.image_mask.convert("L")) - conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 - conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) - - # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 - conditioning_mask = torch.round(conditioning_mask) - else: - conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) - - # Create another latent image, this time with a masked version of the original input. - conditioning_mask = conditioning_mask.to(image.device) - conditioning_image = image * (1.0 - conditioning_mask) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) - - # Create the concatenated conditioning tensor to be fed to `c_concat` - conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) - conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) - self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) - self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) - - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - x = create_random_tensors([opctmp[0] + ctmp = ctmp[0] cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") @@ -106,7 +78,6 @@ def sample( ) return samples, intermediates - @torch.no_grad() def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, -- cgit v1.2.3 From dde9f960727bfe151d418e43685a2881cf580a17 Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Wed, 19 Oct 2022 14:14:24 -0700 Subject: added support for ddim img2img --- modules/sd_samplers.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 9d3cf289..d270e4df 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -208,6 +208,12 @@ class VanillaStableDiffusionSampler: self.init_latent = x self.step = 0 + # Wrap the conditioning models with additional image conditioning for inpainting model + if image_conditioning is not None: + conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + + samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) return samples -- cgit v1.2.3 From c418467c03db916c3e5312e6ac4a67365e196dbd Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Wed, 19 Oct 2022 15:09:43 -0700 Subject: Don't compute latent mask if were not using it. Also added support for fixed highres_fix generation. --- modules/processing.py | 72 +++++++++++++++++++++++++++++++------------------- modules/sd_samplers.py | 4 +++ 2 files changed, 49 insertions(+), 27 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index a6c308f9..684e5833 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -541,12 +541,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) - - if not self.enable_hr: - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - + def create_dummy_mask(self, x): + if self.sampler.conditioning_key in {'hybrid', 'concat'}: # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device) image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) @@ -555,11 +551,23 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) image_conditioning = image_conditioning.to(x.dtype) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=image_conditioning) + else: + # Dummy zero conditioning if we're not using inpainting model. + # Still takes up a bit of memory, but no encoder call. + image_conditioning = torch.zeros(x.shape[0], 5, x.shape[-2], x.shape[-1], dtype=x.dtype, device=x.device) + + return image_conditioning + + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) + + if not self.enable_hr: + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x)) return samples x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x)) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] @@ -596,7 +604,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples)) return samples @@ -723,26 +731,36 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - if self.image_mask is not None: - conditioning_mask = np.array(self.image_mask.convert("L")) - conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 - conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + conditioning_key = self.sampler.conditioning_key - # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 - conditioning_mask = torch.round(conditioning_mask) + if conditioning_key in {'hybrid', 'concat'}: + if self.image_mask is not None: + conditioning_mask = np.array(self.image_mask.convert("L")) + conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 + conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: + conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) + + # Create another latent image, this time with a masked version of the original input. + conditioning_mask = conditioning_mask.to(image.device) + conditioning_image = image * (1.0 - conditioning_mask) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + + # Create the concatenated conditioning tensor to be fed to `c_concat` + conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) + conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) + self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) + self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) else: - conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) - - # Create another latent image, this time with a masked version of the original input. - conditioning_mask = conditioning_mask.to(image.device) - conditioning_image = image * (1.0 - conditioning_mask) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) - - # Create the concatenated conditioning tensor to be fed to `c_concat` - conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) - conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) - self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) - self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) + self.image_conditioning = torch.zeros( + self.init_latent.shape[0], 5, self.init_latent.shape[-2], self.init_latent.shape[-1], + dtype=self.init_latent.dtype, + device=self.init_latent.device + ) + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index d270e4df..c21be26e 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler: self.config = None self.last_latent = None + self.conditioning_key = sd_model.model.conditioning_key + def number_of_needed_noises(self, p): return 0 @@ -328,6 +330,8 @@ class KDiffusionSampler: self.config = None self.last_latent = None + self.conditioning_key = sd_model.model.conditioning_key + def callback_state(self, d): step = d['i'] latent = d["denoised"] -- cgit v1.2.3 From d6ea5841374a28f3f6deb73abc251c8f0bcb240f Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:07:57 +0100 Subject: change html output --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d519cd9..73c1cb80 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -380,7 +380,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log Loss: {mean_loss:.7f}
Step: {hypernetwork.step}
Last prompt: {html.escape(entries[0].cond_text)}
-Last saved embedding: {html.escape(last_saved_file)}
+Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" -- cgit v1.2.3 From 166be3919b817cee5e702fd01c34afe9081b952c Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:09:40 +0100 Subject: allow overwrite old hn --- modules/hypernetworks/ui.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 08f75f15..f45345ea 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -10,9 +10,10 @@ from modules import sd_hijack, shared, devices from modules.hypernetworks import hypernetwork -def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False): +def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") - assert not os.path.exists(fn), f"file {fn} already exists" + if not overwrite_old: + assert not os.path.exists(fn), f"file {fn} already exists" if type(layer_structure) == str: layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure))) -- cgit v1.2.3 From 0087079c2d487b67b06ffc30f36ce486a74e6318 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:10:59 +0100 Subject: allow overwrite old embedding --- modules/textual_inversion/textual_inversion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 3be69562..5776778b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -153,7 +153,7 @@ class EmbeddingDatabase: return None, None -def create_embedding(name, num_vectors_per_token, init_text='*'): +def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): cond_model = shared.sd_model.cond_stage_model embedding_layer = cond_model.wrapped.transformer.text_model.embeddings @@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") - assert not os.path.exists(fn), f"file {fn} already exists" + if not overwrite_old: + assert not os.path.exists(fn), f"file {fn} already exists" embedding = Embedding(vec, name) embedding.step = 0 -- cgit v1.2.3 From 632e8d660293081cadb145d8062e5aff0a4a8f0d Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:19:40 +0100 Subject: split learn rates --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index cdb9d335..d07184ee 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1342,7 +1342,7 @@ def create_ui(wrap_gradio_gpu_call): _js="start_training_textual_inversion", inputs=[ train_embedding_name, - learn_rate, + embedding_learn_rate, batch_size, dataset_directory, log_directory, @@ -1367,7 +1367,7 @@ def create_ui(wrap_gradio_gpu_call): _js="start_training_textual_inversion", inputs=[ train_hypernetwork_name, - learn_rate, + hypernetwork_learn_rate, batch_size, dataset_directory, log_directory, -- cgit v1.2.3 From c3835ec85cbb44fa3c46fa871c622b6fee235c89 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:24:24 +0100 Subject: pass overwrite old flag --- modules/textual_inversion/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index 36881e7a..e712284d 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess from modules import sd_hijack, shared -def create_embedding(name, initialization_text, nvpt): - filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text) +def create_embedding(name, initialization_text, nvpt, overwrite_old): + filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() -- cgit v1.2.3 From 4d6b9f76a55fd0ac0f72634071032dd9c6efb409 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:27:16 +0100 Subject: reorder create_hypernetwork params --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index d07184ee..322c082b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1307,9 +1307,9 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ new_hypernetwork_name, new_hypernetwork_sizes, + overwrite_old_hypernetwork, new_hypernetwork_layer_structure, new_hypernetwork_add_layer_norm, - overwrite_old_hypernetwork, ], outputs=[ train_hypernetwork_name, -- cgit v1.2.3 From fbcce66601994f6ed370db36d9c238840fed6bd2 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:46:54 +0100 Subject: add existing caption file handling --- modules/textual_inversion/preprocess.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 886cf0c3..5c43fe13 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -48,7 +48,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro shared.state.textinfo = "Preprocessing..." shared.state.job_count = len(files) - def save_pic_with_caption(image, index): + def save_pic_with_caption(image, index, existing_caption=None): caption = "" if process_caption: @@ -66,17 +66,26 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro basename = f"{index:05}-{subindex[0]}-{filename_part}" image.save(os.path.join(dst, f"{basename}.png")) + if preprocess_txt_action == 'prepend' and existing_caption: + caption = existing_caption + ' ' + caption + elif preprocess_txt_action == 'append' and existing_caption: + caption = caption + ' ' + existing_caption + elif preprocess_txt_action == 'copy' and existing_caption: + caption = existing_caption + + caption = caption.strip() + if len(caption) > 0: with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file: file.write(caption) subindex[0] += 1 - def save_pic(image, index): + def save_pic(image, index, existing_caption=None): save_pic_with_caption(image, index) if process_flip: - save_pic_with_caption(ImageOps.mirror(image), index) + save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption) for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] @@ -86,6 +95,13 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro except Exception: continue + existing_caption = None + + try: + existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read() + except Exception as e: + print(e) + if shared.state.interrupted: break @@ -97,20 +113,20 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro img = img.resize((width, height * img.height // img.width)) top = img.crop((0, 0, width, height)) - save_pic(top, index) + save_pic(top, index, existing_caption=existing_caption) bot = img.crop((0, img.height - height, width, img.height)) - save_pic(bot, index) + save_pic(bot, index, existing_caption=existing_caption) elif process_split and is_wide: img = img.resize((width * img.width // img.height, height)) left = img.crop((0, 0, width, height)) - save_pic(left, index) + save_pic(left, index, existing_caption=existing_caption) right = img.crop((img.width - width, 0, img.width, height)) - save_pic(right, index) + save_pic(right, index, existing_caption=existing_caption) else: img = images.resize_image(1, img, width, height) - save_pic(img, index) + save_pic(img, index, existing_caption=existing_caption) shared.state.nextjob() -- cgit v1.2.3 From ab353b141df8eee042b0964bcb645015dabf3459 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:48:07 +0100 Subject: link existing txt option --- modules/ui.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 322c082b..7f52ac0c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1234,6 +1234,7 @@ def create_ui(wrap_gradio_gpu_call): process_dst = gr.Textbox(label='Destination directory') process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', choices=['ignore', 'copy', 'prepend', 'append']) with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') @@ -1326,6 +1327,7 @@ def create_ui(wrap_gradio_gpu_call): process_dst, process_width, process_height, + preprocess_txt_action, process_flip, process_split, process_caption, -- cgit v1.2.3 From 9b65c4ecf4f8eb6187ee721918adebe68e9bc631 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:49:23 +0100 Subject: pass preprocess_txt_action param --- modules/textual_inversion/preprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 5c43fe13..3713bc89 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -11,7 +11,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False): try: if process_caption: shared.interrogator.load() @@ -21,7 +21,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru) finally: @@ -33,7 +33,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ -def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False): width = process_width height = process_height src = os.path.abspath(process_src) -- cgit v1.2.3 From 55d8c6cce6d3aef848b9f194adad2ce53064d8b7 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:53:29 +0100 Subject: default to ignore existing captions --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 7f52ac0c..bd5f1b05 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1234,7 +1234,7 @@ def create_ui(wrap_gradio_gpu_call): process_dst = gr.Textbox(label='Destination directory') process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', choices=['ignore', 'copy', 'prepend', 'append']) + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') -- cgit v1.2.3 From 8b74b9aa9a20e4c5c1f72641f8b9617479eb276b Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Wed, 19 Oct 2022 19:06:14 -0500 Subject: add symbol for clear button and simplify roll_col css selector --- modules/ui.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index a2dbd41e..9f6edc5f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -83,6 +83,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 +trash_prompt_symbol = '\U0001F5D1' # 🗑🗑🗑 def plaintext_to_html(text): @@ -498,6 +499,7 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") + trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt") token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") -- cgit v1.2.3 From 6f98e89486f55b0e4657e96ce640cf1c4675d187 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Thu, 20 Oct 2022 00:10:45 +0000 Subject: update --- modules/hypernetworks/hypernetwork.py | 29 +++++++++++++++-------- modules/hypernetworks/ui.py | 3 ++- modules/ui.py | 43 +++++++++++++++++++---------------- 3 files changed, 44 insertions(+), 31 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 74300122..7d617680 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -22,16 +22,20 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): multiplier = 1.0 - def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False): + def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None): super().__init__() - assert layer_structure is not None, "layer_structure mut not be None" + assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" linears = [] for i in range(len(layer_structure) - 1): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) + if activation_func == "relu": + linears.append(torch.nn.ReLU()) + if activation_func == "leakyrelu": + linears.append(torch.nn.LeakyReLU()) if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -42,8 +46,9 @@ class HypernetworkModule(torch.nn.Module): self.load_state_dict(state_dict) else: for layer in self.linear: - layer.weight.data.normal_(mean=0.0, std=0.01) - layer.bias.data.zero_() + if not "ReLU" in layer.__str__(): + layer.weight.data.normal_(mean=0.0, std=0.01) + layer.bias.data.zero_() self.to(devices.device) @@ -69,7 +74,8 @@ class HypernetworkModule(torch.nn.Module): def trainables(self): layer_structure = [] for layer in self.linear: - layer_structure += [layer.weight, layer.bias] + if not "ReLU" in layer.__str__(): + layer_structure += [layer.weight, layer.bias] return layer_structure @@ -81,7 +87,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None): self.filename = None self.name = name self.layers = {} @@ -90,11 +96,12 @@ class Hypernetwork: self.sd_checkpoint_name = None self.layer_structure = layer_structure self.add_layer_norm = add_layer_norm + self.activation_func = activation_func for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), - HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), + HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), ) def weights(self): @@ -117,6 +124,7 @@ class Hypernetwork: state_dict['name'] = self.name state_dict['layer_structure'] = self.layer_structure state_dict['is_layer_norm'] = self.add_layer_norm + state_dict['activation_func'] = self.activation_func state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name @@ -131,12 +139,13 @@ class Hypernetwork: self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) self.add_layer_norm = state_dict.get('is_layer_norm', False) + self.activation_func = state_dict.get('activation_func', None) for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm), - HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func), + HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func), ) self.name = state_dict.get('name', self.name) diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 08f75f15..83f9547b 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -10,7 +10,7 @@ from modules import sd_hijack, shared, devices from modules.hypernetworks import hypernetwork -def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False): +def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" @@ -22,6 +22,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm enable_sizes=[int(x) for x in enable_sizes], layer_structure=layer_structure, add_layer_norm=add_layer_norm, + activation_func=activation_func, ) hypernet.save(fn) diff --git a/modules/ui.py b/modules/ui.py index d2e24880..8751fa9c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -5,43 +5,44 @@ import json import math import mimetypes import os +import platform import random +import subprocess as sp import sys import tempfile import time import traceback -import platform -import subprocess as sp from functools import partial, reduce +import gradio as gr +import gradio.routes +import gradio.utils import numpy as np +import piexif import torch from PIL import Image, PngImagePlugin -import piexif -import gradio as gr -import gradio.utils -import gradio.routes - -from modules import sd_hijack, sd_models, localization +from modules import localization, sd_hijack, sd_models from modules.paths import script_path -from modules.shared import opts, cmd_opts, restricted_opts +from modules.shared import cmd_opts, opts, restricted_opts + if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags -import modules.shared as shared -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.sd_hijack import model_hijack + +import modules.codeformer_model +import modules.generation_parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.images_history as img_his import modules.ldsr_model import modules.scripts -import modules.gfpgan_model -import modules.codeformer_model +import modules.shared as shared import modules.styles -import modules.generation_parameters_copypaste +import modules.textual_inversion.ui from modules import prompt_parser from modules.images import save_image -import modules.textual_inversion.ui -import modules.hypernetworks.ui -import modules.images_history as img_his +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -268,8 +269,8 @@ def calc_time_left(progress, threshold, label, force_display): time_since_start = time.time() - shared.state.time_start eta = (time_since_start/progress) eta_relative = eta-time_since_start - if (eta_relative > threshold and progress > 0.02) or force_display: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + if (eta_relative > threshold and progress > 0.02) or force_display: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) else: return "" @@ -1219,6 +1220,7 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") + new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["relu", "leakyrelu"]) with gr.Row(): with gr.Column(scale=3): @@ -1303,6 +1305,7 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_sizes, new_hypernetwork_layer_structure, new_hypernetwork_add_layer_norm, + new_hypernetwork_activation_func, ], outputs=[ train_hypernetwork_name, -- cgit v1.2.3 From ba469343e6a1c6e23e82acf5feb65c6101dacbb2 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Thu, 20 Oct 2022 00:17:04 +0000 Subject: align ui.py imports with upstream --- modules/ui.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 987b1d7d..913b23b4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -5,44 +5,43 @@ import json import math import mimetypes import os -import platform import random -import subprocess as sp import sys import tempfile import time import traceback +import platform +import subprocess as sp from functools import partial, reduce -import gradio as gr -import gradio.routes -import gradio.utils import numpy as np -import piexif import torch from PIL import Image, PngImagePlugin +import piexif -from modules import localization, sd_hijack, sd_models -from modules.paths import script_path -from modules.shared import cmd_opts, opts, restricted_opts +import gradio as gr +import gradio.utils +import gradio.routes +from modules import sd_hijack, sd_models, localization +from modules.paths import script_path +from modules.shared import opts, cmd_opts, restricted_opts if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags - -import modules.codeformer_model -import modules.generation_parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.images_history as img_his +import modules.shared as shared +from modules.sd_samplers import samplers, samplers_for_img2img +from modules.sd_hijack import model_hijack import modules.ldsr_model import modules.scripts -import modules.shared as shared +import modules.gfpgan_model +import modules.codeformer_model import modules.styles -import modules.textual_inversion.ui +import modules.generation_parameters_copypaste from modules import prompt_parser from modules.images import save_image -from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img +import modules.textual_inversion.ui +import modules.hypernetworks.ui +import modules.images_history as img_his # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() -- cgit v1.2.3 From 59ed74438318af893d2cba552b0e28dbc2a9266c Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 17:19:02 -0700 Subject: face detection algo, configurability, reusability Try to move the crop in the direction of a face if it is present More internal configuration options for choosing weights of each of the algorithm's findings Move logic into its module --- modules/textual_inversion/autocrop.py | 216 ++++++++++++++++++++++++++++++++ modules/textual_inversion/preprocess.py | 150 +++------------------- 2 files changed, 230 insertions(+), 136 deletions(-) create mode 100644 modules/textual_inversion/autocrop.py (limited to 'modules') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py new file mode 100644 index 00000000..f858a958 --- /dev/null +++ b/modules/textual_inversion/autocrop.py @@ -0,0 +1,216 @@ +import cv2 +from collections import defaultdict +from math import log, sqrt +import numpy as np +from PIL import Image, ImageDraw + +GREEN = "#0F0" +BLUE = "#00F" +RED = "#F00" + +def crop_image(im, settings): + """ Intelligently crop an image to the subject matter """ + if im.height > im.width: + im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width)) + else: + im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height)) + + focus = focal_point(im, settings) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(settings.crop_height / 2) + x_half = int(settings.crop_width / 2) + + x1 = focus.x - x_half + if x1 < 0: + x1 = 0 + elif x1 + settings.crop_width > im.width: + x1 = im.width - settings.crop_width + + y1 = focus.y - y_half + if y1 < 0: + y1 = 0 + elif y1 + settings.crop_height > im.height: + y1 = im.height - settings.crop_height + + x2 = x1 + settings.crop_width + y2 = y1 + settings.crop_height + + crop = [x1, y1, x2, y2] + + if settings.annotate_image: + d = ImageDraw.Draw(im) + rect = list(crop) + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + if settings.destop_view_image: + im.show() + + return im.crop(tuple(crop)) + +def focal_point(im, settings): + corner_points = image_corner_points(im, settings) + entropy_points = image_entropy_points(im, settings) + face_points = image_face_points(im, settings) + + total_points = len(corner_points) + len(entropy_points) + len(face_points) + + corner_weight = settings.corner_points_weight + entropy_weight = settings.entropy_points_weight + face_weight = settings.face_points_weight + + weight_pref_total = corner_weight + entropy_weight + face_weight + + # weight things + pois = [] + if weight_pref_total == 0 or total_points == 0: + return pois + + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] + ) + + if settings.annotate_image: + d = ImageDraw.Draw(im) + + average_point = poi_average(pois, settings, im=im) + + if settings.annotate_image: + d.ellipse([average_point.x - 25, average_point.y - 25, average_point.x + 25, average_point.y + 25], outline=GREEN) + + return average_point + + +def image_face_points(im, settings): + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + classifier = cv2.CascadeClassifier(f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml') + + minsize = int(min(im.width, im.height) * 0.15) # at least N percent of the smallest side + faces = classifier.detectMultiScale(gray, scaleFactor=1.05, + minNeighbors=5, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + + if len(faces) == 0: + return [] + + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + if settings.annotate_image: + for f in rects: + d = ImageDraw.Draw(im) + d.rectangle(f, outline=RED) + + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2) for r in rects] + + +def image_corner_points(im, settings): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=100, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.07, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y)) + + return focal_points + + +def image_entropy_points(im, settings): + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] + else: + return [] + + e_max = 0 + crop_current = [0, 0, settings.crop_width, settings.crop_height] + crop_best = crop_current + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e > e_max): + e_max = e + crop_best = list(crop_current) + + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + settings.crop_width/2) + y_mid = int(crop_best[1] + settings.crop_height/2) + + return [PointOfInterest(x_mid, y_mid)] + + +def image_entropy(im): + # greyscale image entropy + band = np.asarray(im.convert("1")) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + + +def poi_average(pois, settings, im=None): + weight = 0.0 + x = 0.0 + y = 0.0 + for pois in pois: + if settings.annotate_image and im is not None: + w = 4 * 0.5 * sqrt(pois.weight) + d = ImageDraw.Draw(im) + d.ellipse([ + pois.x - w, pois.y - w, + pois.x + w, pois.y + w ], fill=BLUE) + weight += pois.weight + x += pois.x * pois.weight + y += pois.y * pois.weight + avg_x = round(x / weight) + avg_y = round(y / weight) + + return PointOfInterest(avg_x, avg_y) + + +class PointOfInterest: + def __init__(self, x, y, weight=1.0): + self.x = x + self.y = y + self.weight = weight + + +class Settings: + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): + self.crop_width = crop_width + self.crop_height = crop_height + self.corner_points_weight = corner_points_weight + self.entropy_points_weight = entropy_points_weight + self.face_points_weight = entropy_points_weight + self.annotate_image = annotate_image + self.destop_view_image = False \ No newline at end of file diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 7c1a594e..0c79f012 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,7 +1,5 @@ import os -import cv2 -import numpy as np -from PIL import Image, ImageOps, ImageDraw +from PIL import Image, ImageOps import platform import sys import tqdm @@ -9,6 +7,7 @@ import time from modules import shared, images from modules.shared import opts, cmd_opts +from modules.textual_inversion import autocrop if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru @@ -80,6 +79,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro if process_flip: save_pic_with_caption(ImageOps.mirror(image), index) + for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] filename = os.path.join(src, imagefile) @@ -118,37 +118,16 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro processing_option_ran = True - if process_entropy_focus and (is_tall or is_wide): - if is_tall: - img = img.resize((width, height * img.height // img.width)) - else: - img = img.resize((width * img.width // img.height, height)) - - x_focal_center, y_focal_center = image_central_focal_point(img, width, height) - - # take the focal point and turn it into crop coordinates that try to center over the focal - # point but then get adjusted back into the frame - y_half = int(height / 2) - x_half = int(width / 2) - - x1 = x_focal_center - x_half - if x1 < 0: - x1 = 0 - elif x1 + width > img.width: - x1 = img.width - width - - y1 = y_focal_center - y_half - if y1 < 0: - y1 = 0 - elif y1 + height > img.height: - y1 = img.height - height - - x2 = x1 + width - y2 = y1 + height - - crop = [x1, y1, x2, y2] - - focal = img.crop(tuple(crop)) + if process_entropy_focus and img.height != img.width: + autocrop_settings = autocrop.Settings( + crop_width = width, + crop_height = height, + face_points_weight = 0.9, + entropy_points_weight = 0.7, + corner_points_weight = 0.5, + annotate_image = False + ) + focal = autocrop.crop_image(img, autocrop_settings) save_pic(focal, index) processing_option_ran = True @@ -157,105 +136,4 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro img = images.resize_image(1, img, width, height) save_pic(img, index) - shared.state.nextjob() - - -def image_central_focal_point(im, target_width, target_height): - focal_points = [] - - focal_points.extend( - image_focal_points(im) - ) - - fp_entropy = image_entropy_point(im, target_width, target_height) - fp_entropy['weight'] = len(focal_points) + 1 # about half of the weight to entropy - - focal_points.append(fp_entropy) - - weight = 0.0 - x = 0.0 - y = 0.0 - for focal_point in focal_points: - weight += focal_point['weight'] - x += focal_point['x'] * focal_point['weight'] - y += focal_point['y'] * focal_point['weight'] - avg_x = round(x // weight) - avg_y = round(y // weight) - - return avg_x, avg_y - - -def image_focal_points(im): - grayscale = im.convert("L") - - # naive attempt at preventing focal points from collecting at watermarks near the bottom - gd = ImageDraw.Draw(grayscale) - gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") - - np_im = np.array(grayscale) - - points = cv2.goodFeaturesToTrack( - np_im, - maxCorners=100, - qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.07, - useHarrisDetector=False, - ) - - if points is None: - return [] - - focal_points = [] - for point in points: - x, y = point.ravel() - focal_points.append({ - 'x': x, - 'y': y, - 'weight': 1.0 - }) - - return focal_points - - -def image_entropy_point(im, crop_width, crop_height): - landscape = im.height < im.width - portrait = im.height > im.width - if landscape: - move_idx = [0, 2] - move_max = im.size[0] - elif portrait: - move_idx = [1, 3] - move_max = im.size[1] - - e_max = 0 - crop_current = [0, 0, crop_width, crop_height] - crop_best = crop_current - while crop_current[move_idx[1]] < move_max: - crop = im.crop(tuple(crop_current)) - e = image_entropy(crop) - - if (e > e_max): - e_max = e - crop_best = list(crop_current) - - crop_current[move_idx[0]] += 4 - crop_current[move_idx[1]] += 4 - - x_mid = int(crop_best[0] + crop_width/2) - y_mid = int(crop_best[1] + crop_height/2) - - - return { - 'x': x_mid, - 'y': y_mid, - 'weight': 1.0 - } - - -def image_entropy(im): - # greyscale image entropy - band = np.asarray(im.convert("1")) - hist, _ = np.histogram(band, bins=range(0, 256)) - hist = hist[hist > 0] - return -np.log2(hist / hist.sum()).sum() - + shared.state.nextjob() \ No newline at end of file -- cgit v1.2.3 From 858462f719c22ca9f24b94a41699653c34b5f4fb Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 02:57:18 +0100 Subject: do caption copy for both flips --- modules/textual_inversion/preprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 3713bc89..6bba3852 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -82,7 +82,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre subindex[0] += 1 def save_pic(image, index, existing_caption=None): - save_pic_with_caption(image, index) + save_pic_with_caption(image, index, existing_caption=existing_caption) if process_flip: save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption) -- cgit v1.2.3 From c6345bd445463b7aa41723d6637e80dfa293a890 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Wed, 19 Oct 2022 21:23:57 -0500 Subject: nerf line length --- modules/ui.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 9f6edc5f..cb9a6c6e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -83,7 +83,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 -trash_prompt_symbol = '\U0001F5D1' # 🗑🗑🗑 +trash_prompt_symbol = '\U0001F5D1' # def plaintext_to_html(text): @@ -617,7 +617,10 @@ def create_ui(wrap_gradio_gpu_call): return refresh_button with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\ + txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\ + token_button = create_toprow(is_img2img=False) + dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) -- cgit v1.2.3 From aa7ff2a1972f3865883e10ba28c5414cdebe8e3b Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Wed, 19 Oct 2022 21:46:13 -0700 Subject: Fixed non-square highres fix generation --- modules/processing.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 684e5833..3caac25e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -541,10 +541,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def create_dummy_mask(self, x): + def create_dummy_mask(self, x, first_phase: bool = False): if self.sampler.conditioning_key in {'hybrid', 'concat'}: + height = self.firstphase_height if first_phase else self.height + width = self.firstphase_width if first_phase else self.width + # The "masked-image" in this case will just be all zeros since the entire image is masked. - image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device) + image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) # Add the fake full 1s mask to the first dimension. @@ -567,7 +570,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return samples x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x)) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, first_phase=True)) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] -- cgit v1.2.3 From 930b4c64f7dbce6918894d53538003e5959fd022 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 20 Oct 2022 08:18:02 +0300 Subject: allow float sizes for hypernet's layer_structure --- modules/hypernetworks/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 08f75f15..e0741d08 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -15,7 +15,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm assert not os.path.exists(fn), f"file {fn} already exists" if type(layer_structure) == str: - layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure))) + layer_structure = [float(x.strip()) for x in layer_structure.split(",")] hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( name=name, -- cgit v1.2.3 From 158d678f596d7fc304a6ce2f0dc31f8abfe62250 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Thu, 20 Oct 2022 01:08:24 -0500 Subject: clear prompt button now works on both relevant tabs. Device detection stuff will be added later. --- modules/ui.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index cb9a6c6e..bde546cc 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -424,6 +424,16 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox +# setup button for clearing prompt input boxes on client side of webui +def connect_trash_prompt(dummy_component, button, is_img2img): + + button.click( + fn=lambda: print("Clearing prompt"), + _js="trash_prompt", + inputs=[], + outputs=[], + ) + def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): """ Connects a 'reuse (sub)seed' button's click event so that it copies last used (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength @@ -540,7 +550,7 @@ def create_toprow(is_img2img): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) prompt_style2.save_to_config = True - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, trash_prompt def setup_progressbar(progressbar, preview, id_part, textinfo=None): @@ -619,10 +629,11 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\ txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\ - token_button = create_toprow(is_img2img=False) + token_button, trash_prompt_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + connect_trash_prompt(dummy_component, trash_prompt_button, False) with gr.Row(elem_id='txt2img_progress_row'): with gr.Column(scale=1): @@ -807,7 +818,11 @@ def create_ui(wrap_gradio_gpu_call): token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True) + img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit,\ + img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ + token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) + + connect_trash_prompt(dummy_component,trash_prompt_button, True) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) -- cgit v1.2.3 From 0ddaf8d2028a7251e8c4ad93551a43b5d4700841 Mon Sep 17 00:00:00 2001 From: captin411 Date: Thu, 20 Oct 2022 00:34:55 -0700 Subject: improve face detection a lot --- modules/textual_inversion/autocrop.py | 99 ++++++++++++++++++++++------------- 1 file changed, 62 insertions(+), 37 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index f858a958..5a551c25 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -8,12 +8,18 @@ GREEN = "#0F0" BLUE = "#00F" RED = "#F00" + def crop_image(im, settings): """ Intelligently crop an image to the subject matter """ if im.height > im.width: im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width)) - else: + elif im.width > im.height: im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height)) + else: + im = im.resize((settings.crop_width, settings.crop_height)) + + if im.height == im.width: + return im focus = focal_point(im, settings) @@ -78,13 +84,18 @@ def focal_point(im, settings): [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] ) - if settings.annotate_image: - d = ImageDraw.Draw(im) - - average_point = poi_average(pois, settings, im=im) + average_point = poi_average(pois, settings) if settings.annotate_image: - d.ellipse([average_point.x - 25, average_point.y - 25, average_point.x + 25, average_point.y + 25], outline=GREEN) + d = ImageDraw.Draw(im) + for f in face_points: + d.rectangle(f.bounding(f.size), outline=RED) + for f in entropy_points: + d.rectangle(f.bounding(30), outline=BLUE) + for poi in pois: + w = max(4, 4 * 0.5 * sqrt(poi.weight)) + d.ellipse(poi.bounding(w), fill=BLUE) + d.ellipse(average_point.bounding(25), outline=GREEN) return average_point @@ -92,22 +103,32 @@ def focal_point(im, settings): def image_face_points(im, settings): np_im = np.array(im) gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - classifier = cv2.CascadeClassifier(f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml') - - minsize = int(min(im.width, im.height) * 0.15) # at least N percent of the smallest side - faces = classifier.detectMultiScale(gray, scaleFactor=1.05, - minNeighbors=5, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - if len(faces) == 0: - return [] - - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - if settings.annotate_image: - for f in rects: - d = ImageDraw.Draw(im) - d.rectangle(f, outline=RED) - - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2) for r in rects] + tries = [ + [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] + ] + + for t in tries: + # print(t[0]) + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + except: + continue + + if len(faces) > 0: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] + return [] def image_corner_points(im, settings): @@ -132,8 +153,8 @@ def image_corner_points(im, settings): focal_points = [] for point in points: - x, y = point.ravel() - focal_points.append(PointOfInterest(x, y)) + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4)) return focal_points @@ -167,31 +188,26 @@ def image_entropy_points(im, settings): x_mid = int(crop_best[0] + settings.crop_width/2) y_mid = int(crop_best[1] + settings.crop_height/2) - return [PointOfInterest(x_mid, y_mid)] + return [PointOfInterest(x_mid, y_mid, size=25)] def image_entropy(im): # greyscale image entropy - band = np.asarray(im.convert("1")) + # band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1"), dtype=np.uint8) hist, _ = np.histogram(band, bins=range(0, 256)) hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() -def poi_average(pois, settings, im=None): +def poi_average(pois, settings): weight = 0.0 x = 0.0 y = 0.0 - for pois in pois: - if settings.annotate_image and im is not None: - w = 4 * 0.5 * sqrt(pois.weight) - d = ImageDraw.Draw(im) - d.ellipse([ - pois.x - w, pois.y - w, - pois.x + w, pois.y + w ], fill=BLUE) - weight += pois.weight - x += pois.x * pois.weight - y += pois.y * pois.weight + for poi in pois: + weight += poi.weight + x += poi.x * poi.weight + y += poi.y * poi.weight avg_x = round(x / weight) avg_y = round(y / weight) @@ -199,10 +215,19 @@ def poi_average(pois, settings, im=None): class PointOfInterest: - def __init__(self, x, y, weight=1.0): + def __init__(self, x, y, weight=1.0, size=10): self.x = x self.y = y self.weight = weight + self.size = size + + def bounding(self, size): + return [ + self.x - size//2, + self.y - size//2, + self.x + size//2, + self.y + size//2 + ] class Settings: -- cgit v1.2.3 From f8733ad08be08bafb40f4299785590e11f049e96 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Thu, 20 Oct 2022 11:07:37 +0000 Subject: add linear as a act func (option for doin nothing) --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 913b23b4..716f14b8 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1224,7 +1224,7 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") - new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["relu", "leakyrelu"]) + new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"]) with gr.Row(): with gr.Column(scale=3): -- cgit v1.2.3 From 9681419e422515e42444e0174355b760645a846f Mon Sep 17 00:00:00 2001 From: Milly Date: Thu, 20 Oct 2022 16:53:46 +0900 Subject: train: fixed preprocess image ratio --- modules/textual_inversion/preprocess.py | 54 +++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 19 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 886cf0c3..2743bdeb 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,5 +1,6 @@ import os from PIL import Image, ImageOps +import math import platform import sys import tqdm @@ -38,6 +39,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) + split_threshold = 0.5 + overlap_ratio = 0.2 assert src != dst, 'same directory specified as source and destination' @@ -78,6 +81,29 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro if process_flip: save_pic_with_caption(ImageOps.mirror(image), index) + def split_pic(image, inverse_xy): + if inverse_xy: + from_w, from_h = image.height, image.width + to_w, to_h = height, width + else: + from_w, from_h = image.width, image.height + to_w, to_h = width, height + h = from_h * to_w // from_w + if inverse_xy: + image = image.resize((h, to_w)) + else: + image = image.resize((to_w, h)) + + split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) + y_step = (h - to_h) / (split_count - 1) + for i in range(split_count): + y = int(y_step * i) + if inverse_xy: + splitted = image.crop((y, 0, y + to_h, to_w)) + else: + splitted = image.crop((0, y, to_w, y + to_h)) + yield splitted + for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] filename = os.path.join(src, imagefile) @@ -89,26 +115,16 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro if shared.state.interrupted: break - ratio = img.height / img.width - is_tall = ratio > 1.35 - is_wide = ratio < 1 / 1.35 - - if process_split and is_tall: - img = img.resize((width, height * img.height // img.width)) - - top = img.crop((0, 0, width, height)) - save_pic(top, index) - - bot = img.crop((0, img.height - height, width, img.height)) - save_pic(bot, index) - elif process_split and is_wide: - img = img.resize((width * img.width // img.height, height)) - - left = img.crop((0, 0, width, height)) - save_pic(left, index) + if img.height > img.width: + ratio = (img.width * height) / (img.height * width) + inverse_xy = False + else: + ratio = (img.height * width) / (img.width * height) + inverse_xy = True - right = img.crop((img.width - width, 0, img.width, height)) - save_pic(right, index) + if process_split and ratio < 1.0 and ratio <= split_threshold: + for splitted in split_pic(img, inverse_xy): + save_pic(splitted, index) else: img = images.resize_image(1, img, width, height) save_pic(img, index) -- cgit v1.2.3 From 85dd62c4c7635b8e21a75f140d093036069e97a1 Mon Sep 17 00:00:00 2001 From: Milly Date: Thu, 20 Oct 2022 22:56:45 +0900 Subject: train: ui: added `Split image threshold` and `Split image overlap ratio` to preprocess --- modules/textual_inversion/preprocess.py | 10 +++++----- modules/ui.py | 16 ++++++++++++++-- 2 files changed, 19 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 2743bdeb..c8df8aa0 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,7 +12,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2): try: if process_caption: shared.interrogator.load() @@ -22,7 +22,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio) finally: @@ -34,13 +34,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ -def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2): width = process_width height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) - split_threshold = 0.5 - overlap_ratio = 0.2 + split_threshold = max(0.0, min(1.0, split_threshold)) + overlap_ratio = max(0.0, min(0.9, overlap_ratio)) assert src != dst, 'same directory specified as source and destination' diff --git a/modules/ui.py b/modules/ui.py index a2dbd41e..bc7f3330 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1240,10 +1240,14 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') - process_split = gr.Checkbox(label='Split oversized images into two') + process_split = gr.Checkbox(label='Split oversized images') process_caption = gr.Checkbox(label='Use BLIP for caption') process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) + with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") @@ -1251,6 +1255,12 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): run_preprocess = gr.Button(value="Preprocess", variant='primary') + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + with gr.Tab(label="Train"): gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") with gr.Row(): @@ -1327,7 +1337,9 @@ def create_ui(wrap_gradio_gpu_call): process_flip, process_split, process_caption, - process_caption_deepbooru + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, ], outputs=[ ti_output, -- cgit v1.2.3 From d8acd34f66ab35a91f10d66330bcc95a83bfcac6 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Thu, 20 Oct 2022 23:43:03 +0900 Subject: generalized some functions and option for ignoring first layer --- modules/hypernetworks/hypernetwork.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d617680..3a44b377 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -21,21 +21,27 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): multiplier = 1.0 - + activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU, + "swish": torch.nn.Hardswish} + def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None): super().__init__() assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - + linears = [] for i in range(len(layer_structure) - 1): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) - if activation_func == "relu": - linears.append(torch.nn.ReLU()) - if activation_func == "leakyrelu": - linears.append(torch.nn.LeakyReLU()) + # if skip_first_layer because first parameters potentially contain negative values + if i < 1: continue + if activation_func in HypernetworkModule.activation_dict: + linears.append(HypernetworkModule.activation_dict[activation_func]()) + else: + print("Invalid key {} encountered as activation function!".format(activation_func)) + # if use_dropout: + linears.append(torch.nn.Dropout(p=0.3)) if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module): self.load_state_dict(state_dict) else: for layer in self.linear: - if not "ReLU" in layer.__str__(): + if isinstance(layer, torch.nn.Linear): layer.weight.data.normal_(mean=0.0, std=0.01) layer.bias.data.zero_() @@ -298,7 +304,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) + # if optimizer == "Adam": or else Adam / AdamW / etc... + optimizer = torch.optim.Adam(weights, lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: -- cgit v1.2.3 From a71e0212363979c7cbbb797c9fbd5f8cd03b29d3 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Thu, 20 Oct 2022 23:48:52 +0900 Subject: only linear --- modules/hypernetworks/hypernetwork.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3a44b377..905cbeef 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -35,13 +35,13 @@ class HypernetworkModule(torch.nn.Module): for i in range(len(layer_structure) - 1): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) # if skip_first_layer because first parameters potentially contain negative values - if i < 1: continue + # if i < 1: continue if activation_func in HypernetworkModule.activation_dict: linears.append(HypernetworkModule.activation_dict[activation_func]()) else: print("Invalid key {} encountered as activation function!".format(activation_func)) # if use_dropout: - linears.append(torch.nn.Dropout(p=0.3)) + # linears.append(torch.nn.Dropout(p=0.3)) if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module): def trainables(self): layer_structure = [] for layer in self.linear: - if not "ReLU" in layer.__str__(): + if isinstance(layer, torch.nn.Linear): layer_structure += [layer.weight, layer.bias] return layer_structure @@ -304,8 +304,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - # if optimizer == "Adam": or else Adam / AdamW / etc... - optimizer = torch.optim.Adam(weights, lr=scheduler.learn_rate) + # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... + optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: -- cgit v1.2.3 From d07cb46f34b3d9fe7a78b102f899ebef352ea56b Mon Sep 17 00:00:00 2001 From: yfszzx Date: Thu, 20 Oct 2022 23:58:52 +0800 Subject: inspiration pull request --- modules/inspiration.py | 122 +++++++++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 1 + modules/ui.py | 13 ++++-- 3 files changed, 131 insertions(+), 5 deletions(-) create mode 100644 modules/inspiration.py (limited to 'modules') diff --git a/modules/inspiration.py b/modules/inspiration.py new file mode 100644 index 00000000..456bfcb5 --- /dev/null +++ b/modules/inspiration.py @@ -0,0 +1,122 @@ +import os +import random +import gradio +inspiration_path = "inspiration" +inspiration_system_path = os.path.join(inspiration_path, "system") +def read_name_list(file): + if not os.path.exists(file): + return [] + f = open(file, "r") + ret = [] + line = f.readline() + while len(line) > 0: + line = line.rstrip("\n") + ret.append(line) + print(ret) + return ret + +def save_name_list(file, name): + print(file) + f = open(file, "a") + f.write(name + "\n") + +def get_inspiration_images(source, types): + path = os.path.join(inspiration_path , types) + if source == "Favorites": + names = read_name_list(os.path.join(inspiration_system_path, types + "_faverites.txt")) + names = random.sample(names, 25) + elif source == "Abandoned": + names = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) + names = random.sample(names, 25) + elif source == "Exclude abandoned": + abondened = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) + all_names = os.listdir(path) + names = [] + while len(names) < 25: + name = random.choice(all_names) + if name not in abondened: + names.append(name) + else: + names = random.sample(os.listdir(path), 25) + names = random.sample(names, 25) + image_list = [] + for a in names: + image_path = os.path.join(path, a) + images = os.listdir(image_path) + image_list.append(os.path.join(image_path, random.choice(images))) + return image_list, names + +def select_click(index, types, name_list): + name = name_list[int(index)] + path = os.path.join(inspiration_path, types, name) + images = os.listdir(path) + return name, [os.path.join(path, x) for x in images] + +def give_up_click(name, types): + file = os.path.join(inspiration_system_path, types + "_abandoned.txt") + name_list = read_name_list(file) + if name not in name_list: + save_name_list(file, name) + +def collect_click(name, types): + file = os.path.join(inspiration_system_path, types + "_faverites.txt") + print(file) + name_list = read_name_list(file) + print(name_list) + if name not in name_list: + save_name_list(file, name) + +def moveout_click(name, types): + file = os.path.join(inspiration_system_path, types + "_faverites.txt") + name_list = read_name_list(file) + if name not in name_list: + save_name_list(file, name) + +def source_change(source): + if source == "Abandoned" or source == "Favorites": + return gradio.Button.update(visible=True, value=f"Move out {source}") + else: + return gradio.Button.update(visible=False) + +def ui(gr, opts): + with gr.Blocks(analytics_enabled=False) as inspiration: + flag = os.path.exists(inspiration_path) + if flag: + types = os.listdir(inspiration_path) + types = [x for x in types if x != "system"] + flag = len(types) > 0 + if not flag: + os.mkdir(inspiration_path) + gr.HTML(""" +
" + """) + return inspiration + if not os.path.exists(inspiration_system_path): + os.mkdir(inspiration_system_path) + gallery, names = get_inspiration_images("Exclude abandoned", types[0]) + with gr.Row(): + with gr.Column(scale=2): + inspiration_gallery = gr.Gallery(gallery, show_label=False, elem_id="inspiration_gallery").style(grid=5, height='auto') + with gr.Column(scale=1): + types = gr.Dropdown(choices=types, value=types[0], label="Type", visible=len(types) > 1) + with gr.Row(): + source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") + get_inspiration = gr.Button("Get inspiration") + name = gr.Textbox(show_label=False, interactive=False) + with gr.Row(): + send_to_txt2img = gr.Button('to txt2img') + send_to_img2img = gr.Button('to img2img') + style_gallery = gr.Gallery(show_label=False, elem_id="inspiration_style_gallery").style(grid=2, height='auto') + + collect = gr.Button('Collect') + give_up = gr.Button("Don't show any more") + moveout = gr.Button("Move out", visible=False) + with gr.Row(): + select_button = gr.Button('set button', elem_id="inspiration_select_button") + name_list = gr.State(names) + source.change(source_change, inputs=[source], outputs=[moveout]) + get_inspiration.click(get_inspiration_images, inputs=[source, types], outputs=[inspiration_gallery, name_list]) + select_button.click(select_click, _js="inspiration_selected", inputs=[name, types, name_list], outputs=[name, style_gallery]) + give_up.click(give_up_click, inputs=[name, types], outputs=None) + collect.click(collect_click, inputs=[name, types], outputs=None) + return inspiration diff --git a/modules/shared.py b/modules/shared.py index faede821..ae033710 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -78,6 +78,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") cmd_opts = parser.parse_args() restricted_opts = [ diff --git a/modules/ui.py b/modules/ui.py index a2dbd41e..6a0a3c3b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -41,7 +41,8 @@ from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.images_history as img_his +import modules.images_history as images_history +import modules.inspiration as inspiration # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -1082,9 +1083,9 @@ def create_ui(wrap_gradio_gpu_call): upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) - + with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers] , value=shared.sd_upscalers[0].name, type="index") with gr.Group(): extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") @@ -1178,7 +1179,8 @@ def create_ui(wrap_gradio_gpu_call): "i2i":img2img_paste_fields } - images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) + browser_interface = images_history.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) + inspiration_interface = inspiration.ui(gr, opts) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): @@ -1595,7 +1597,8 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (images_history, "History", "images_history"), + (browser_interface, "History", "images_history"), + (inspiration_interface, "Inspiration", "inspiration"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), (settings_interface, "Settings", "settings"), -- cgit v1.2.3 From 108be15500aac590b4e00420635d7b61fccfa530 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Fri, 21 Oct 2022 01:00:41 +0900 Subject: fix bugs and optimizations --- modules/hypernetworks/hypernetwork.py | 105 +++++++++++++++++++--------------- 1 file changed, 59 insertions(+), 46 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 905cbeef..893ba110 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -36,14 +36,14 @@ class HypernetworkModule(torch.nn.Module): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) # if skip_first_layer because first parameters potentially contain negative values # if i < 1: continue + if add_layer_norm: + linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) if activation_func in HypernetworkModule.activation_dict: linears.append(HypernetworkModule.activation_dict[activation_func]()) else: print("Invalid key {} encountered as activation function!".format(activation_func)) # if use_dropout: # linears.append(torch.nn.Dropout(p=0.3)) - if add_layer_norm: - linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) self.linear = torch.nn.Sequential(*linears) @@ -115,11 +115,24 @@ class Hypernetwork: for k, layers in self.layers.items(): for layer in layers: - layer.train() res += layer.trainables() return res + def eval(self): + for k, layers in self.layers.items(): + for layer in layers: + layer.eval() + for items in self.weights(): + items.requires_grad = False + + def train(self): + for k, layers in self.layers.items(): + for layer in layers: + layer.train() + for items in self.weights(): + items.requires_grad = True + def save(self, filename): state_dict = {} @@ -290,10 +303,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.sd_model.first_stage_model.to(devices.cpu) hypernetwork = shared.loaded_hypernetwork - weights = hypernetwork.weights() - for weight in weights: - weight.requires_grad = True - losses = torch.zeros((32,)) last_saved_file = "" @@ -304,10 +313,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... - optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) + optimizer = torch.optim.AdamW(hypernetwork.weights(), lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) + hypernetwork.train() for i, entries in pbar: hypernetwork.step = i + ititial_step @@ -328,8 +337,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log losses[hypernetwork.step % losses.shape[0]] = loss.item() - optimizer.zero_grad() + optimizer.zero_grad(set_to_none=True) loss.backward() + del loss optimizer.step() mean_loss = losses.mean() if torch.isnan(mean_loss): @@ -346,44 +356,47 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: + torch.cuda.empty_cache() last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') + with torch.no_grad(): + hypernetwork.eval() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + ) - optimizer.zero_grad() - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) - - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - ) - - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - - preview_text = p.prompt - - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images)>0 else None - - if unload: - shared.sd_model.cond_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.cpu) - - if image is not None: - shared.state.current_image = image - image.save(last_saved_image) - last_saved_image += f", prompt: {preview_text}" + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + + preview_text = p.prompt + + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images)>0 else None + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + + if image is not None: + shared.state.current_image = image + image.save(last_saved_image) + last_saved_image += f", prompt: {preview_text}" + + hypernetwork.train() shared.state.job_no = hypernetwork.step -- cgit v1.2.3 From f89829ec3a0baceb445451ad98d4fb4323e922aa Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 21 Oct 2022 01:37:11 +0900 Subject: Revert "fix bugs and optimizations" This reverts commit 108be15500aac590b4e00420635d7b61fccfa530. --- modules/hypernetworks/hypernetwork.py | 105 +++++++++++++++------------------- 1 file changed, 46 insertions(+), 59 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 893ba110..905cbeef 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -36,14 +36,14 @@ class HypernetworkModule(torch.nn.Module): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) # if skip_first_layer because first parameters potentially contain negative values # if i < 1: continue - if add_layer_norm: - linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) if activation_func in HypernetworkModule.activation_dict: linears.append(HypernetworkModule.activation_dict[activation_func]()) else: print("Invalid key {} encountered as activation function!".format(activation_func)) # if use_dropout: # linears.append(torch.nn.Dropout(p=0.3)) + if add_layer_norm: + linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) self.linear = torch.nn.Sequential(*linears) @@ -115,24 +115,11 @@ class Hypernetwork: for k, layers in self.layers.items(): for layer in layers: + layer.train() res += layer.trainables() return res - def eval(self): - for k, layers in self.layers.items(): - for layer in layers: - layer.eval() - for items in self.weights(): - items.requires_grad = False - - def train(self): - for k, layers in self.layers.items(): - for layer in layers: - layer.train() - for items in self.weights(): - items.requires_grad = True - def save(self, filename): state_dict = {} @@ -303,6 +290,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.sd_model.first_stage_model.to(devices.cpu) hypernetwork = shared.loaded_hypernetwork + weights = hypernetwork.weights() + for weight in weights: + weight.requires_grad = True + losses = torch.zeros((32,)) last_saved_file = "" @@ -313,10 +304,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - optimizer = torch.optim.AdamW(hypernetwork.weights(), lr=scheduler.learn_rate) + # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... + optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - hypernetwork.train() for i, entries in pbar: hypernetwork.step = i + ititial_step @@ -337,9 +328,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log losses[hypernetwork.step % losses.shape[0]] = loss.item() - optimizer.zero_grad(set_to_none=True) + optimizer.zero_grad() loss.backward() - del loss optimizer.step() mean_loss = losses.mean() if torch.isnan(mean_loss): @@ -356,47 +346,44 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: - torch.cuda.empty_cache() last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') - with torch.no_grad(): - hypernetwork.eval() - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) - - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - ) - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - - preview_text = p.prompt - - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images)>0 else None - - if unload: - shared.sd_model.cond_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.cpu) - - if image is not None: - shared.state.current_image = image - image.save(last_saved_image) - last_saved_image += f", prompt: {preview_text}" - - hypernetwork.train() + optimizer.zero_grad() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + + preview_text = p.prompt + + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images)>0 else None + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + + if image is not None: + shared.state.current_image = image + image.save(last_saved_image) + last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step -- cgit v1.2.3 From 92a17a7a4a13fceb3c3e25a2e854b2a7dd6eb5df Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 20 Oct 2022 09:45:03 -0700 Subject: Made dummy latents smaller. Minor code cleanups --- modules/processing.py | 7 ++++--- modules/sd_samplers.py | 6 ++++-- 2 files changed, 8 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 3caac25e..539cde38 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -557,7 +557,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: # Dummy zero conditioning if we're not using inpainting model. # Still takes up a bit of memory, but no encoder call. - image_conditioning = torch.zeros(x.shape[0], 5, x.shape[-2], x.shape[-1], dtype=x.dtype, device=x.device) + # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. + image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) return image_conditioning @@ -759,8 +760,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) else: self.image_conditioning = torch.zeros( - self.init_latent.shape[0], 5, self.init_latent.shape[-2], self.init_latent.shape[-1], - dtype=self.init_latent.dtype, + self.init_latent.shape[0], 5, 1, 1, + dtype=self.init_latent.dtype, device=self.init_latent.device ) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index c21be26e..cc682593 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -138,7 +138,7 @@ class VanillaStableDiffusionSampler: if self.stop_at is not None and self.step > self.stop_at: raise InterruptedException - # Have to unwrap the inpainting conditioning here to perform pre-preocessing + # Have to unwrap the inpainting conditioning here to perform pre-processing image_conditioning = None if isinstance(cond, dict): image_conditioning = cond["c_concat"][0] @@ -146,7 +146,7 @@ class VanillaStableDiffusionSampler: unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) + unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' cond = tensor @@ -165,6 +165,8 @@ class VanillaStableDiffusionSampler: img_orig = self.sampler.model.q_sample(self.init_latent, ts) x_dec = img_orig * self.mask + self.nmask * x_dec + # Wrap the image conditioning back up since the DDIM code can accept the dict directly. + # Note that they need to be lists because it just concatenates them later. if image_conditioning is not None: cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} -- cgit v1.2.3 From d1cb08bfb221cd1b0cfc6078162b4e206ea80a5c Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Thu, 20 Oct 2022 22:49:06 +0300 Subject: fix skip and interrupt for highres. fix option --- modules/processing.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index bcb0c32c..6324ca91 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -587,9 +587,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) - - return samples + return self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) or samples class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): -- cgit v1.2.3 From 708c3a7bd8ce68cbe1aa7c268e5a4b1980affc9f Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 20 Oct 2022 13:28:43 -0700 Subject: Added PLMS hijack and made sure to always replace methods --- modules/sd_hijack_inpainting.py | 163 ++++++++++++++++++++++++++++++++++++++-- modules/sd_models.py | 3 +- 2 files changed, 157 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index d4d28d2e..43938071 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -1,16 +1,14 @@ import torch -import numpy as np -from tqdm import tqdm -from einops import rearrange, repeat +from einops import repeat from omegaconf import ListConfig -from types import MethodType - import ldm.models.diffusion.ddpm import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.ddim import DDIMSampler, noise_like # ================================================================================================= @@ -19,7 +17,7 @@ from ldm.models.diffusion.ddim import DDIMSampler, noise_like # https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py # ================================================================================================= @torch.no_grad() -def sample(self, +def sample_ddim(self, S, batch_size, shape, @@ -132,6 +130,153 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F return x_prev, pred_x0 +# ================================================================================================= +# Monkey patch PLMSSampler methods. +# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes. +# Adapted from: +# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py +# ================================================================================================= +@torch.no_grad() +def sample_plms(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + +@torch.no_grad() +def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t + # ================================================================================================= # Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config. # Adapted from: @@ -175,5 +320,9 @@ def should_hijack_inpainting(checkpoint_info): def do_inpainting_hijack(): ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion + ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim - ldm.models.diffusion.ddim.DDIMSampler.sample = sample \ No newline at end of file + ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim + + ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms + ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms \ No newline at end of file diff --git a/modules/sd_models.py b/modules/sd_models.py index 47836d25..7072db08 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -214,8 +214,6 @@ def load_model(): sd_config = OmegaConf.load(checkpoint_info.config) if should_hijack_inpainting(checkpoint_info): - do_inpainting_hijack() - # Hardcoded config for now... sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" sd_config.model.params.use_ema = False @@ -225,6 +223,7 @@ def load_model(): # Create a "fake" config with a different name so that we know to unload it when switching models. checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) -- cgit v1.2.3 From d23a46ceaa76af2847f11172f32c92665c268b1b Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Thu, 20 Oct 2022 23:49:14 +0300 Subject: Different approach to skip/interrupt with highres fix --- modules/processing.py | 4 +++- modules/sd_samplers.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 6324ca91..bcb0c32c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -587,7 +587,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() - return self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) or samples + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) + + return samples class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index b58e810b..7ff77c01 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -196,6 +196,7 @@ class VanillaStableDiffusionSampler: x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) self.init_latent = x + self.last_latent = x self.step = 0 samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) @@ -206,6 +207,7 @@ class VanillaStableDiffusionSampler: self.initialize(p) self.init_latent = None + self.last_latent = x self.step = 0 steps = steps or p.steps @@ -388,6 +390,7 @@ class KDiffusionSampler: extra_params_kwargs['sigmas'] = sigma_sched self.model_wrap_cfg.init_latent = x + self.last_latent = x samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) @@ -414,6 +417,7 @@ class KDiffusionSampler: else: extra_params_kwargs['sigmas'] = sigmas + self.last_latent = x samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples -- cgit v1.2.3 From 49533eed9e3aad19e9868ee140708baec4fd44be Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 20 Oct 2022 16:01:27 -0700 Subject: XY grid correctly re-assignes model when config changes --- modules/sd_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 7072db08..fea84630 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -204,9 +204,9 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info -def load_model(): +def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack - checkpoint_info = select_checkpoint() + checkpoint_info = checkpoint_info or select_checkpoint() if checkpoint_info.config != shared.cmd_opts.config: print(f"Loading config from: {checkpoint_info.config}") @@ -249,7 +249,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - shared.sd_model = load_model() + shared.sd_model = load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: -- cgit v1.2.3 From a3b047b7c74dc6ca07f40aee778997fc1889d72f Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Thu, 20 Oct 2022 19:28:58 -0500 Subject: add settings option to toggle button visibility --- modules/shared.py | 1 + modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index faede821..7e9c2696 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -300,6 +300,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + "trash_prompt_visible": OptionInfo(True, "Show trash prompt button"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui.py b/modules/ui.py index bde546cc..13c0b4ca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -509,7 +509,7 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt") + trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt", visible=opts.trash_prompt_visible) token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") -- cgit v1.2.3 From 45872181902ada06267e2de601586d512cf5df1a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 09:00:39 +0300 Subject: updated readme and some small stylistic changes to code --- modules/processing.py | 14 ++++++-------- modules/sd_hijack_inpainting.py | 3 +++ 2 files changed, 9 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 539cde38..21786968 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -540,11 +540,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - - def create_dummy_mask(self, x, first_phase: bool = False): + def create_dummy_mask(self, x, width=None, height=None): if self.sampler.conditioning_key in {'hybrid', 'concat'}: - height = self.firstphase_height if first_phase else self.height - width = self.firstphase_width if first_phase else self.width + height = height or self.height + width = width or self.width # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) @@ -571,7 +570,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return samples x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, first_phase=True)) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height)) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] @@ -634,6 +633,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.inpainting_mask_invert = inpainting_mask_invert self.mask = None self.nmask = None + self.image_conditioning = None def init(self, all_prompts, all_seeds, all_subseeds): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) @@ -735,9 +735,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - conditioning_key = self.sampler.conditioning_key - - if conditioning_key in {'hybrid', 'concat'}: + if self.sampler.conditioning_key in {'hybrid', 'concat'}: if self.image_mask is not None: conditioning_mask = np.array(self.image_mask.convert("L")) conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 43938071..fd92a335 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -301,6 +301,7 @@ def get_unconditional_conditioning(self, batch_size, null_label=None): c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) return c + class LatentInpaintDiffusion(LatentDiffusion): def __init__( self, @@ -314,9 +315,11 @@ class LatentInpaintDiffusion(LatentDiffusion): assert self.masked_image_key in concat_keys self.concat_keys = concat_keys + def should_hijack_inpainting(checkpoint_info): return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml") + def do_inpainting_hijack(): ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion -- cgit v1.2.3 From 74088c2a06a975092806362aede22f82716cb011 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 20 Oct 2022 08:18:02 +0300 Subject: allow float sizes for hypernet's layer_structure --- modules/hypernetworks/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 08f75f15..e0741d08 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -15,7 +15,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm assert not os.path.exists(fn), f"file {fn} already exists" if type(layer_structure) == str: - layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure))) + layer_structure = [float(x.strip()) for x in layer_structure.split(",")] hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( name=name, -- cgit v1.2.3 From 60872c5b404114336f9ca0c671ba88fa4a8201c9 Mon Sep 17 00:00:00 2001 From: winterspringsummer Date: Thu, 20 Oct 2022 19:10:32 +0900 Subject: Fixed path issue while extras batch processing --- modules/extras.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index b853fa5b..f9796624 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -118,10 +118,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ while len(cached_images) > 2: del cached_images[next(iter(cached_images.keys()))] - - images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, - forced_filename=image_name if opts.use_original_name_batch else None) + + if opts.use_original_name_batch and image_name != None: + basename = os.path.splitext(os.path.basename(image_name))[0] + else: + basename = '' + + images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) if opts.enable_pnginfo: image.info = existing_pnginfo -- cgit v1.2.3 From fb5a8cf0d9ed027ea3aa2e5422c946d8e6e72efe Mon Sep 17 00:00:00 2001 From: winterspringsummer Date: Thu, 20 Oct 2022 21:31:29 +0900 Subject: Added try except to extras batch from directory --- modules/extras.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index f9796624..0d817cf9 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -41,7 +41,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ return outputs, "Please select an input directory.", '' image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)] for img in image_list: - image = Image.open(img) + try: + image = Image.open(img) + except Exception: + continue imageArr.append(image) imageNameArr.append(img) else: @@ -122,10 +125,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ if opts.use_original_name_batch and image_name != None: basename = os.path.splitext(os.path.basename(image_name))[0] else: - basename = '' + basename = None - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) + images.save_image(image, path=outpath, basename='', seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=basename) if opts.enable_pnginfo: image.info = existing_pnginfo -- cgit v1.2.3 From a13c3bed3cec27afe3c015d3d62db36e25b10d1f Mon Sep 17 00:00:00 2001 From: winterspringsummer Date: Thu, 20 Oct 2022 21:43:27 +0900 Subject: Fixed path issue while extras batch processing --- modules/extras.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 0d817cf9..ac85142c 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -125,10 +125,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ if opts.use_original_name_batch and image_name != None: basename = os.path.splitext(os.path.basename(image_name))[0] else: - basename = None + basename = '' - images.save_image(image, path=outpath, basename='', seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=basename) + images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) if opts.enable_pnginfo: image.info = existing_pnginfo -- cgit v1.2.3 From 9d71eef02e7395e179b8d5e61e6d91ddd8928d2e Mon Sep 17 00:00:00 2001 From: winterspringsummer Date: Fri, 21 Oct 2022 09:23:13 +0900 Subject: sort file list in alphabetical ordering in extras --- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index ac85142c..22c5a1c1 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -39,7 +39,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ if input_dir == '': return outputs, "Please select an input directory.", '' - image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)] + image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)] for img in image_list: try: image = Image.open(img) -- cgit v1.2.3 From c23f666dba2b484d521d2dc4be91cf9e09312647 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 09:47:43 +0300 Subject: a more strict check for activation type and a more reasonable check for type of layer in hypernets --- modules/hypernetworks/hypernetwork.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d617680..84e7e350 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -32,10 +32,16 @@ class HypernetworkModule(torch.nn.Module): linears = [] for i in range(len(layer_structure) - 1): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) + if activation_func == "relu": linears.append(torch.nn.ReLU()) - if activation_func == "leakyrelu": + elif activation_func == "leakyrelu": linears.append(torch.nn.LeakyReLU()) + elif activation_func == 'linear' or activation_func is None: + pass + else: + raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') + if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module): self.load_state_dict(state_dict) else: for layer in self.linear: - if not "ReLU" in layer.__str__(): + if type(layer) == torch.nn.Linear: layer.weight.data.normal_(mean=0.0, std=0.01) layer.bias.data.zero_() @@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module): def trainables(self): layer_structure = [] for layer in self.linear: - if not "ReLU" in layer.__str__(): + if type(layer) == torch.nn.Linear: layer_structure += [layer.weight, layer.bias] return layer_structure -- cgit v1.2.3 From 7157e5d064741fa57ca81a2c6432a651f21ee82f Mon Sep 17 00:00:00 2001 From: Patryk Wychowaniec Date: Thu, 20 Oct 2022 19:22:59 +0200 Subject: interrogate: Fix CLIP-interrogation on CPU Currently, trying to perform CLIP interrogation on a CPU fails, saying: ``` RuntimeError: "slow_conv2d_cpu" not implemented for 'Half' ``` This merge request fixes this issue by detecting whether the target device is CPU and, if so, force-enabling `--no-half` and passing `device="cpu"` to `clip.load()` (which then does some extra tricks to ensure it works correctly on CPU). --- modules/interrogate.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/interrogate.py b/modules/interrogate.py index 64b91eb4..65b05d34 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -28,9 +28,11 @@ class InterrogateModels: clip_preprocess = None categories = None dtype = None + running_on_cpu = None def __init__(self, content_dir): self.categories = [] + self.running_on_cpu = devices.device_interrogate == torch.device("cpu") if os.path.exists(content_dir): for filename in os.listdir(content_dir): @@ -53,7 +55,11 @@ class InterrogateModels: def load_clip_model(self): import clip - model, preprocess = clip.load(clip_model_name) + if self.running_on_cpu: + model, preprocess = clip.load(clip_model_name, device="cpu") + else: + model, preprocess = clip.load(clip_model_name) + model.eval() model = model.to(devices.device_interrogate) @@ -62,14 +68,14 @@ class InterrogateModels: def load(self): if self.blip_model is None: self.blip_model = self.load_blip_model() - if not shared.cmd_opts.no_half: + if not shared.cmd_opts.no_half and not self.running_on_cpu: self.blip_model = self.blip_model.half() self.blip_model = self.blip_model.to(devices.device_interrogate) if self.clip_model is None: self.clip_model, self.clip_preprocess = self.load_clip_model() - if not shared.cmd_opts.no_half: + if not shared.cmd_opts.no_half and not self.running_on_cpu: self.clip_model = self.clip_model.half() self.clip_model = self.clip_model.to(devices.device_interrogate) -- cgit v1.2.3 From b69c37d25e4ffc56e8f8c247fa2c38b4648cefb7 Mon Sep 17 00:00:00 2001 From: guaneec Date: Thu, 20 Oct 2022 22:21:12 +0800 Subject: Allow datasets with only 1 image in TI --- modules/textual_inversion/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 23bb4b6a..5b1c5002 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -83,7 +83,7 @@ class PersonalizedBase(Dataset): self.dataset.append(entry) - assert len(self.dataset) > 1, "No images have been found in the dataset." + assert len(self.dataset) > 0, "No images have been found in the dataset." self.length = len(self.dataset) * repeats // batch_size self.initial_indexes = np.arange(len(self.dataset)) @@ -91,7 +91,7 @@ class PersonalizedBase(Dataset): self.shuffle() def shuffle(self): - self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()] def create_text(self, filename_text): text = random.choice(self.lines) -- cgit v1.2.3 From 5245c7a4935f67b677da0f5a1fc2b74c074aa0e2 Mon Sep 17 00:00:00 2001 From: timntorres Date: Wed, 19 Oct 2022 12:21:32 -0700 Subject: Issue #2921-Give PNG info to Hypernet previews. --- modules/hypernetworks/hypernetwork.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 84e7e350..68c8f26d 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -256,6 +256,9 @@ def stack_conds(conds): def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + # images is required here to give training previews their infotext. Importing this at the very top causes a circular dependency. + from modules import images + assert hypernetwork_name, 'hypernetwork not selected' path = shared.hypernetworks.get(hypernetwork_name, None) @@ -298,6 +301,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log last_saved_file = "" last_saved_image = "" + forced_filename = "" ititial_step = hypernetwork.step or 0 if ititial_step > steps: @@ -345,7 +349,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: - last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') + forced_filename = f'{hypernetwork_name}-{hypernetwork.step}' + last_saved_image = os.path.join(images_dir, forced_filename) optimizer.zero_grad() shared.sd_model.cond_stage_model.to(devices.device) @@ -381,7 +386,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if image is not None: shared.state.current_image = image - image.save(last_saved_image) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step -- cgit v1.2.3 From 6014fb8afbe05c8d02fffe7a36a2e48128713bd2 Mon Sep 17 00:00:00 2001 From: timntorres Date: Wed, 19 Oct 2022 12:22:23 -0700 Subject: Do nothing if image file already exists. --- modules/images.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index b9589563..550e53ae 100644 --- a/modules/images.py +++ b/modules/images.py @@ -416,7 +416,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /') path = os.path.join(path, dirname) - os.makedirs(path, exist_ok=True) + try: + os.makedirs(path, exist_ok=True) + except FileExistsError: + # If the file already exists, continue and allow said file to be overwritten. + pass if forced_filename is None: basecount = get_next_sequence_number(path, basename) -- cgit v1.2.3 From 4ff274e1e35bb642687253ce744d2cfa738ab293 Mon Sep 17 00:00:00 2001 From: timntorres Date: Wed, 19 Oct 2022 12:32:22 -0700 Subject: Revise comments. --- modules/hypernetworks/hypernetwork.py | 2 +- modules/images.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 68c8f26d..3f96361c 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -256,7 +256,7 @@ def stack_conds(conds): def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # images is required here to give training previews their infotext. Importing this at the very top causes a circular dependency. + # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images assert hypernetwork_name, 'hypernetwork not selected' diff --git a/modules/images.py b/modules/images.py index 550e53ae..b8834e3c 100644 --- a/modules/images.py +++ b/modules/images.py @@ -419,7 +419,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i try: os.makedirs(path, exist_ok=True) except FileExistsError: - # If the file already exists, continue and allow said file to be overwritten. + # If the file already exists, allow said file to be overwritten. pass if forced_filename is None: -- cgit v1.2.3 From 2273e752fb3e578f1047f6d38b96330b07bf61a9 Mon Sep 17 00:00:00 2001 From: timntorres Date: Wed, 19 Oct 2022 14:23:48 -0700 Subject: Remove redundant try/except. --- modules/images.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index b8834e3c..b9589563 100644 --- a/modules/images.py +++ b/modules/images.py @@ -416,11 +416,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /') path = os.path.join(path, dirname) - try: - os.makedirs(path, exist_ok=True) - except FileExistsError: - # If the file already exists, allow said file to be overwritten. - pass + os.makedirs(path, exist_ok=True) if forced_filename is None: basecount = get_next_sequence_number(path, basename) -- cgit v1.2.3 From 03a1e288c4973dd2dff57a97469b40f146b6fccf Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 10:13:24 +0300 Subject: turns out LayerNorm also has weight and bias and needs to be pre-multiplied and trained for hypernets --- modules/hypernetworks/hypernetwork.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3274a802..b1a5d0c7 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -52,7 +52,7 @@ class HypernetworkModule(torch.nn.Module): self.load_state_dict(state_dict) else: for layer in self.linear: - if type(layer) == torch.nn.Linear: + if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm: layer.weight.data.normal_(mean=0.0, std=0.01) layer.bias.data.zero_() @@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module): def trainables(self): layer_structure = [] for layer in self.linear: - if type(layer) == torch.nn.Linear: + if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm: layer_structure += [layer.weight, layer.bias] return layer_structure -- cgit v1.2.3 From bf30673f5132c8f28357b31224c54331e788d3e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 10:19:25 +0300 Subject: Fix Hypernet infotext string split bug for PR #3283 --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 21786968..d1deffa9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]), + "Hypernet": (None if shared.loaded_hypernetwork is None else os.path.splitext(os.path.basename(shared.loaded_hypernetwork.filename))[0]), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), -- cgit v1.2.3 From df5706409386cc2e88718bd9101045587c39f8bb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 16:10:51 +0300 Subject: do not load aesthetic clip model until it's needed add refresh button for aesthetic embeddings add aesthetic params to images' infotext --- modules/aesthetic_clip.py | 40 +++++++++++++++++++---- modules/generation_parameters_copypaste.py | 18 +++++++++-- modules/img2img.py | 5 +-- modules/processing.py | 4 +-- modules/sd_models.py | 3 -- modules/txt2img.py | 4 +-- modules/ui.py | 52 ++++++++++++++++++++---------- 7 files changed, 88 insertions(+), 38 deletions(-) (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index 34efa931..8c828541 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -40,6 +40,8 @@ def iter_to_batched(iterable, n=1): def create_ui(): + import modules.ui + with gr.Group(): with gr.Accordion("Open for Clip Aesthetic!", open=False): with gr.Row(): @@ -55,6 +57,8 @@ def create_ui(): label="Aesthetic imgs embedding", value="None") + modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings") + with gr.Row(): aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", @@ -66,11 +70,21 @@ def create_ui(): return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative +aesthetic_clip_model = None + + +def aesthetic_clip(): + global aesthetic_clip_model + + if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path: + aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path) + aesthetic_clip_model.cpu() + + return aesthetic_clip_model + + def generate_imgs_embd(name, folder, batch_size): - # clipModel = CLIPModel.from_pretrained( - # shared.sd_model.cond_stage_model.clipModel.name_or_path - # ) - model = shared.clip_model.to(device) + model = aesthetic_clip().to(device) processor = CLIPProcessor.from_pretrained(model.name_or_path) with torch.no_grad(): @@ -91,7 +105,7 @@ def generate_imgs_embd(name, folder, batch_size): path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") torch.save(embs, path) - model = model.cpu() + model.cpu() del processor del embs gc.collect() @@ -132,7 +146,7 @@ class AestheticCLIP: self.image_embs = None self.load_image_embs(None) - def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, + def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, aesthetic_slerp=True, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False): @@ -145,6 +159,18 @@ class AestheticCLIP: self.aesthetic_steps = aesthetic_steps self.load_image_embs(image_embs_name) + if self.image_embs_name is not None: + p.extra_generation_params.update({ + "Aesthetic LR": aesthetic_lr, + "Aesthetic weight": aesthetic_weight, + "Aesthetic steps": aesthetic_steps, + "Aesthetic embedding": self.image_embs_name, + "Aesthetic slerp": aesthetic_slerp, + "Aesthetic text": aesthetic_imgs_text, + "Aesthetic text negative": aesthetic_text_negative, + "Aesthetic slerp angle": aesthetic_slerp_angle, + }) + def set_skip(self, skip): self.skip = skip @@ -168,7 +194,7 @@ class AestheticCLIP: tokens = torch.asarray(remade_batch_tokens).to(device) - model = copy.deepcopy(shared.clip_model).to(device) + model = copy.deepcopy(aesthetic_clip()).to(device) model.requires_grad_(True) if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: text_embs_2 = model.get_text_features( diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 0f041449..f73647da 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -4,13 +4,22 @@ import gradio as gr from modules.shared import script_path from modules import shared -re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" +re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") type_of_gr_update = type(gr.update()) +def quote(text): + if ',' not in str(text): + return text + + text = str(text) + text = text.replace('\\', '\\\\') + text = text.replace('"', '\\"') + return f'"{text}"' + def parse_generation_parameters(x: str): """parses generation parameters string, the one you see in text field under the picture in UI: ``` @@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None): else: try: valtype = type(output.value) - val = valtype(v) + + if valtype == bool and v == "False": + val = False + else: + val = valtype(v) + res.append(gr.update(value=val)) except Exception: res.append(gr.update()) diff --git a/modules/img2img.py b/modules/img2img.py index bc7c66bc..eea5199b 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -109,10 +109,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpainting_mask_invert=inpainting_mask_invert, ) - shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), - aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative) + shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index d1deffa9..f0852cd5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -12,7 +12,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -318,7 +318,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params.update(p.extra_generation_params) - generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" diff --git a/modules/sd_models.py b/modules/sd_models.py index 05a1df28..b1c91b0d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -234,9 +234,6 @@ def load_model(checkpoint_info=None): sd_hijack.model_hijack.hijack(sd_model) - if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path: - shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path) - sd_model.eval() print(f"Model loaded.") diff --git a/modules/txt2img.py b/modules/txt2img.py index 32ed1d8d..1761cfa2 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -36,9 +36,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: firstphase_height=firstphase_height if enable_hr else None, ) - shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), - aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, - aesthetic_text_negative) + shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/ui.py b/modules/ui.py index 381ca925..0d020de6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -597,27 +597,29 @@ def apply_setting(key, value): return value -def create_ui(wrap_gradio_gpu_call): - import modules.img2img - import modules.txt2img +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args - def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args + for k, v in args.items(): + setattr(refresh_component, k, v) - for k, v in args.items(): - setattr(refresh_component, k, v) + return gr.update(**(args or {})) - return gr.update(**(args or {})) + refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + +def create_ui(wrap_gradio_gpu_call): + import modules.img2img + import modules.txt2img - refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn = refresh, - inputs = [], - outputs = [refresh_component] - ) - return refresh_button with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) @@ -802,6 +804,14 @@ def create_ui(wrap_gradio_gpu_call): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (firstphase_width, "First pass size-1"), (firstphase_height, "First pass size-2"), + (aesthetic_lr, "Aesthetic LR"), + (aesthetic_weight, "Aesthetic weight"), + (aesthetic_steps, "Aesthetic steps"), + (aesthetic_imgs, "Aesthetic embedding"), + (aesthetic_slerp, "Aesthetic slerp"), + (aesthetic_imgs_text, "Aesthetic text"), + (aesthetic_text_negative, "Aesthetic text negative"), + (aesthetic_slerp_angle, "Aesthetic slerp angle"), ] txt2img_preview_params = [ @@ -1077,6 +1087,14 @@ def create_ui(wrap_gradio_gpu_call): (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), (denoising_strength, "Denoising strength"), + (aesthetic_lr_im, "Aesthetic LR"), + (aesthetic_weight_im, "Aesthetic weight"), + (aesthetic_steps_im, "Aesthetic steps"), + (aesthetic_imgs_im, "Aesthetic embedding"), + (aesthetic_slerp_im, "Aesthetic slerp"), + (aesthetic_imgs_text_im, "Aesthetic text"), + (aesthetic_text_negative_im, "Aesthetic text negative"), + (aesthetic_slerp_angle_im, "Aesthetic slerp angle"), ] token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) -- cgit v1.2.3 From 9286fe53de2eef91f13cc3ad5938ddf67ecc8413 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 16:38:06 +0300 Subject: make aestetic embedding ciompatible with prompts longer than 75 tokens --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 36198a3c..1f8587d1 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -332,8 +332,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers.append([1.0] * 75) z1 = self.process_tokens(tokens, multipliers) + z1 = shared.aesthetic_clip(z1, remade_batch_tokens) z = z1 if z is None else torch.cat((z, z1), axis=-2) - z = shared.aesthetic_clip(z, remade_batch_tokens) remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers -- cgit v1.2.3 From d0ea471b0cdaede163c6e7f6fae8535f5c3cd226 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 21 Oct 2022 14:04:41 +0100 Subject: Use opts in textual_inversion image_embedding.py for dynamic fonts --- modules/textual_inversion/image_embedding.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index 898ce3b3..c50b1e7b 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -5,6 +5,7 @@ import zlib from PIL import Image, PngImagePlugin, ImageDraw, ImageFont from fonts.ttf import Roboto import torch +from modules.shared import opts class EmbeddingEncoder(json.JSONEncoder): -- cgit v1.2.3 From 306e2ff6ab8f4c7e94ab55f4f08ab8f94d73d287 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 21 Oct 2022 14:47:21 +0100 Subject: Update image_embedding.py --- modules/textual_inversion/image_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index c50b1e7b..ea653806 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -134,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t from math import cos image = srcimage.copy() - + fontsize = 32 if textfont is None: try: textfont = ImageFont.truetype(opts.font or Roboto, fontsize) @@ -151,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) draw = ImageDraw.Draw(image) - fontsize = 32 + font = ImageFont.truetype(textfont, fontsize) padding = 10 -- cgit v1.2.3 From 51e3dc9ccad157d7161b697a246e26c868d46a7c Mon Sep 17 00:00:00 2001 From: timntorres Date: Fri, 21 Oct 2022 02:11:12 -0700 Subject: Sanitize hypernet name input. --- modules/hypernetworks/ui.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 266f04f6..e6f50a1f 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -11,6 +11,9 @@ from modules.hypernetworks import hypernetwork def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None): + # Remove illegal characters from name. + name = "".join( x for x in name if (x.isalnum() or x in "._- ")) + fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") if not overwrite_old: assert not os.path.exists(fn), f"file {fn} already exists" -- cgit v1.2.3 From 19818f023cfafc472c6c241cab0b72896a168481 Mon Sep 17 00:00:00 2001 From: timntorres Date: Fri, 21 Oct 2022 02:14:02 -0700 Subject: Match hypernet name with filename in all cases. --- modules/hypernetworks/hypernetwork.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b1a5d0c7..6d392be4 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -340,7 +340,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log pbar.set_description(f"loss: {mean_loss:.7f}") if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') + temp = hypernetwork.name + # Before saving, change name to match current checkpoint. + hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') hypernetwork.save(last_saved_file) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { @@ -405,6 +408,9 @@ Last saved image: {html.escape(last_saved_image)}
hypernetwork.sd_checkpoint = checkpoint.hash hypernetwork.sd_checkpoint_name = checkpoint.model_name + # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention). + hypernetwork.name = hypernetwork_name + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt') hypernetwork.save(filename) return hypernetwork, filename -- cgit v1.2.3 From fccad18a59e3c2c33fefbbb1763c6a87a3a68eba Mon Sep 17 00:00:00 2001 From: timntorres Date: Fri, 21 Oct 2022 02:17:26 -0700 Subject: Refer to Hypernet's name, sensibly, by its name variable. --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index f0852cd5..ff1ec4c9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else os.path.splitext(os.path.basename(shared.loaded_hypernetwork.filename))[0]), + "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), -- cgit v1.2.3 From 272fa527bbe93143668ffc16838107b7dca35b40 Mon Sep 17 00:00:00 2001 From: timntorres Date: Fri, 21 Oct 2022 02:41:55 -0700 Subject: Remove unused variable. --- modules/hypernetworks/hypernetwork.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 6d392be4..47d91ea5 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -340,7 +340,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log pbar.set_description(f"loss: {mean_loss:.7f}") if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: - temp = hypernetwork.name # Before saving, change name to match current checkpoint. hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}' last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') -- cgit v1.2.3 From 02e4d4694dd9254a6ca9f05c2eb7b01ea508abc7 Mon Sep 17 00:00:00 2001 From: Rcmcpe Date: Fri, 21 Oct 2022 15:53:35 +0800 Subject: Change option description of unload_models_when_training --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 5c675b80..41d7f08e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -266,7 +266,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('training', "Training"), { - "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), + "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), -- cgit v1.2.3 From 704036ff07b71bf86cadcbbff2bcfeebdd1ed3a6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 17:11:42 +0300 Subject: make aspect ratio overlay work regardless of selected localization --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 0d020de6..85f95792 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -879,8 +879,8 @@ def create_ui(wrap_gradio_gpu_call): sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") with gr.Group(): - width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height") with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) -- cgit v1.2.3 From ac0aa2b18efeeb9220a5994c8dd54c7cdda7cc40 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 17:35:51 +0300 Subject: loading SD VAE, see PR #3303 --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index b1c91b0d..d99dbce8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -155,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd +vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} + + def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash @@ -186,7 +189,7 @@ def load_model_weights(model, checkpoint_info): if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} model.first_stage_model.load_state_dict(vae_dict) model.first_stage_model.to(devices.dtype_vae) -- cgit v1.2.3 From f49c08ea566385db339c6628f65c3a121033f67c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 18:46:02 +0300 Subject: prevent error spam when processing images without txt files for captions --- modules/textual_inversion/preprocess.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 17e4ddc1..33eaddb6 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -122,11 +122,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre continue existing_caption = None - - try: - existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read() - except Exception as e: - print(e) + existing_caption_filename = os.path.splitext(filename)[0] + '.txt' + if os.path.exists(existing_caption_filename): + with open(existing_caption_filename, 'r', encoding="utf8") as file: + existing_caption = file.read() if shared.state.interrupted: break -- cgit v1.2.3 From bb0f1a2cdae3410a41d06ae878f56e29b8154c41 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sat, 22 Oct 2022 01:23:00 +0800 Subject: inspiration finished --- modules/inspiration.py | 192 ++++++++++++++++++++++++++++++++----------------- modules/shared.py | 6 ++ modules/ui.py | 2 +- 3 files changed, 133 insertions(+), 67 deletions(-) (limited to 'modules') diff --git a/modules/inspiration.py b/modules/inspiration.py index 456bfcb5..f72ebf3a 100644 --- a/modules/inspiration.py +++ b/modules/inspiration.py @@ -1,122 +1,182 @@ import os import random -import gradio -inspiration_path = "inspiration" -inspiration_system_path = os.path.join(inspiration_path, "system") -def read_name_list(file): +import gradio +from modules.shared import opts +inspiration_system_path = os.path.join(opts.inspiration_dir, "system") +def read_name_list(file, types=None, keyword=None): if not os.path.exists(file): return [] - f = open(file, "r") ret = [] + f = open(file, "r") line = f.readline() while len(line) > 0: line = line.rstrip("\n") - ret.append(line) - print(ret) + if types is not None: + dirname = os.path.split(line) + if dirname[0] in types and keyword in dirname[1]: + ret.append(line) + else: + ret.append(line) + line = f.readline() return ret def save_name_list(file, name): - print(file) - f = open(file, "a") - f.write(name + "\n") + with open(file, "a") as f: + f.write(name + "\n") -def get_inspiration_images(source, types): - path = os.path.join(inspiration_path , types) +def get_types_list(): + files = os.listdir(opts.inspiration_dir) + types = [] + for x in files: + path = os.path.join(opts.inspiration_dir, x) + if x[0] == ".": + continue + if not os.path.isdir(path): + continue + if path == inspiration_system_path: + continue + types.append(x) + return types + +def get_inspiration_images(source, types, keyword): + get_num = int(opts.inspiration_rows_num * opts.inspiration_cols_num) if source == "Favorites": - names = read_name_list(os.path.join(inspiration_system_path, types + "_faverites.txt")) - names = random.sample(names, 25) + names = read_name_list(os.path.join(inspiration_system_path, "faverites.txt"), types, keyword) + names = random.sample(names, get_num) if len(names) > get_num else names elif source == "Abandoned": - names = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) - names = random.sample(names, 25) - elif source == "Exclude abandoned": - abondened = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) - all_names = os.listdir(path) - names = [] - while len(names) < 25: - name = random.choice(all_names) - if name not in abondened: - names.append(name) + names = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) + print(names) + names = random.sample(names, get_num) if len(names) > get_num else names + elif source == "Exclude abandoned": + abandoned = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) + all_names = [] + for tp in types: + name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) + all_names += [os.path.join(tp, x) for x in name_list if keyword in x] + + if len(all_names) > get_num: + names = [] + while len(names) < get_num: + name = random.choice(all_names) + if name not in abandoned: + names.append(name) + else: + names = all_names else: - names = random.sample(os.listdir(path), 25) - names = random.sample(names, 25) + all_names = [] + for tp in types: + name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) + all_names += [os.path.join(tp, x) for x in name_list if keyword in x] + names = random.sample(all_names, get_num) if len(all_names) > get_num else all_names image_list = [] for a in names: - image_path = os.path.join(path, a) + image_path = os.path.join(opts.inspiration_dir, a) images = os.listdir(image_path) - image_list.append(os.path.join(image_path, random.choice(images))) - return image_list, names + image_list.append((os.path.join(image_path, random.choice(images)), a)) + return image_list, names, "" -def select_click(index, types, name_list): +def select_click(index, name_list): name = name_list[int(index)] - path = os.path.join(inspiration_path, types, name) + path = os.path.join(opts.inspiration_dir, name) images = os.listdir(path) - return name, [os.path.join(path, x) for x in images] + return name, [os.path.join(path, x) for x in images], "" -def give_up_click(name, types): - file = os.path.join(inspiration_system_path, types + "_abandoned.txt") +def give_up_click(name): + file = os.path.join(inspiration_system_path, "abandoned.txt") name_list = read_name_list(file) if name not in name_list: save_name_list(file, name) + return "Added to abandoned list" -def collect_click(name, types): - file = os.path.join(inspiration_system_path, types + "_faverites.txt") - print(file) +def collect_click(name): + file = os.path.join(inspiration_system_path, "faverites.txt") name_list = read_name_list(file) - print(name_list) if name not in name_list: save_name_list(file, name) + return "Added to faverite list" -def moveout_click(name, types): - file = os.path.join(inspiration_system_path, types + "_faverites.txt") +def moveout_click(name, source): + if source == "Abandoned": + file = os.path.join(inspiration_system_path, "abandoned.txt") + if source == "Favorites": + file = os.path.join(inspiration_system_path, "faverites.txt") + else: + return None name_list = read_name_list(file) - if name not in name_list: - save_name_list(file, name) + os.remove(file) + with open(file, "a") as f: + for a in name_list: + if a != name: + f.write(a) + return "Moved out {name} from {source} list" def source_change(source): - if source == "Abandoned" or source == "Favorites": - return gradio.Button.update(visible=True, value=f"Move out {source}") + if source in ["Abandoned", "Favorites"]: + return gradio.update(visible=True), [] else: - return gradio.Button.update(visible=False) + return gradio.update(visible=False), [] +def add_to_prompt(name, prompt): + print(name, prompt) + name = os.path.basename(name) + return prompt + "," + name -def ui(gr, opts): +def ui(gr, opts, txt2img_prompt, img2img_prompt): with gr.Blocks(analytics_enabled=False) as inspiration: - flag = os.path.exists(inspiration_path) + flag = os.path.exists(opts.inspiration_dir) if flag: - types = os.listdir(inspiration_path) - types = [x for x in types if x != "system"] + types = get_types_list() flag = len(types) > 0 - if not flag: - os.mkdir(inspiration_path) + else: + os.makedirs(opts.inspiration_dir) + if not flag: gr.HTML(""" -
" +

To activate inspiration function, you need get "inspiration" images first.


+ You can create these images by run "Create inspiration images" script in txt2img page,
you can get the artists or art styles list from here
+ https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
+ download these files, and select these files in the "Create inspiration images" script UI
+ There about 6000 artists and art styles in these files.
This takes server hours depending on your GPU type and how many pictures you generate for each artist/style +
I suggest at least four images for each


+

You can also download generated pictures from here:


+ https://huggingface.co/datasets/yfszzx/inspiration
+ unzip the file to the project directory of webui
+ and restart webui, and enjoy the joy of creation!
""") return inspiration if not os.path.exists(inspiration_system_path): os.mkdir(inspiration_system_path) - gallery, names = get_inspiration_images("Exclude abandoned", types[0]) with gr.Row(): with gr.Column(scale=2): - inspiration_gallery = gr.Gallery(gallery, show_label=False, elem_id="inspiration_gallery").style(grid=5, height='auto') + inspiration_gallery = gr.Gallery(show_label=False, elem_id="inspiration_gallery").style(grid=opts.inspiration_cols_num, height='auto') with gr.Column(scale=1): - types = gr.Dropdown(choices=types, value=types[0], label="Type", visible=len(types) > 1) + print(types) + types = gr.CheckboxGroup(choices=types, value=types) + keyword = gr.Textbox("", label="Key word") with gr.Row(): source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") - get_inspiration = gr.Button("Get inspiration") + get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button") name = gr.Textbox(show_label=False, interactive=False) with gr.Row(): send_to_txt2img = gr.Button('to txt2img') send_to_img2img = gr.Button('to img2img') - style_gallery = gr.Gallery(show_label=False, elem_id="inspiration_style_gallery").style(grid=2, height='auto') - + style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') collect = gr.Button('Collect') - give_up = gr.Button("Don't show any more") + give_up = gr.Button("Don't show again") moveout = gr.Button("Move out", visible=False) - with gr.Row(): + warning = gr.HTML() + with gr.Row(visible=False): select_button = gr.Button('set button', elem_id="inspiration_select_button") - name_list = gr.State(names) - source.change(source_change, inputs=[source], outputs=[moveout]) - get_inspiration.click(get_inspiration_images, inputs=[source, types], outputs=[inspiration_gallery, name_list]) - select_button.click(select_click, _js="inspiration_selected", inputs=[name, types, name_list], outputs=[name, style_gallery]) - give_up.click(give_up_click, inputs=[name, types], outputs=None) - collect.click(collect_click, inputs=[name, types], outputs=None) + name_list = gr.State() + + get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list, keyword]) + source.change(source_change, inputs=[source], outputs=[moveout, style_gallery]) + source.change(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + keyword.submit(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + select_button.click(select_click, _js="inspiration_selected", inputs=[name, name_list], outputs=[name, style_gallery, warning]) + give_up.click(give_up_click, inputs=[name], outputs=[warning]) + collect.click(collect_click, inputs=[name], outputs=[warning]) + moveout.click(moveout_click, inputs=[name, source], outputs=[warning]) + send_to_txt2img.click(add_to_prompt, inputs=[name, txt2img_prompt], outputs=[txt2img_prompt]) + send_to_img2img.click(add_to_prompt, inputs=[name, img2img_prompt], outputs=[img2img_prompt]) + send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None) + send_to_img2img.click(None, _js="switch_to_img2img_img2img", inputs=None, outputs=None) return inspiration diff --git a/modules/shared.py b/modules/shared.py index ae033710..564b1b8d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -316,6 +316,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) +options_templates.update(options_section(('inspiration', "Inspiration"), { + "inspiration_dir": OptionInfo("inspiration", "Directory of inspiration", component_args=hide_dirs), + "inspiration_max_samples": OptionInfo(4, "Maximum number of samples, used to determine which folders to skip when continue running the create script", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), + "inspiration_rows_num": OptionInfo(4, "Rows of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), + "inspiration_cols_num": OptionInfo(8, "Columns of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), +})) class Options: data = None diff --git a/modules/ui.py b/modules/ui.py index 6a0a3c3b..b651eb9c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1180,7 +1180,7 @@ def create_ui(wrap_gradio_gpu_call): } browser_interface = images_history.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) - inspiration_interface = inspiration.ui(gr, opts) + inspiration_interface = inspiration.ui(gr, opts, txt2img_prompt, img2img_prompt) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): -- cgit v1.2.3 From 9ba372de90df81c4f1e992d8b33ae17c6630de95 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 13:55:42 -0500 Subject: initial work on getting prompts cleared on the backend and synchronizing token counter --- modules/ui.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index d2cb528e..2748a2e0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,15 +429,16 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -# setup button for clearing prompt input boxes on client side of webui -def connect_trash_prompt(dummy_component, button, is_img2img): +def clear_prompt(prompt): + print(f"type: '{prompt}'") + print(prompt) + + # update_token_counter(prompt, steps) + return "" + +def connect_trash_prompt(prompt, confirmed): + if(confirmed): clear_prompt(prompt) - button.click( - fn=lambda: print("Clearing prompt"), - _js="trash_prompt", - inputs=[], - outputs=[], - ) def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): """ Connects a 'reuse (sub)seed' button's click event so that it copies last used @@ -640,7 +641,16 @@ def create_ui(wrap_gradio_gpu_call): dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - connect_trash_prompt(dummy_component, trash_prompt_button, False) + + + trash_prompt_button.click( + # fn=lambda: print("Clearing prompt"), + _js="trash_prompt", + fn=lambda: clear_prompt(), + inputs=[txt2img_prompt, dummy_component], + outputs=[txt2img_prompt, dummy_component], + ) + with gr.Row(elem_id='txt2img_progress_row'): with gr.Column(scale=1): @@ -848,7 +858,6 @@ def create_ui(wrap_gradio_gpu_call): img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) - connect_trash_prompt(dummy_component,trash_prompt_button, True) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) -- cgit v1.2.3 From ee0505dd0092ae7073b77aba93a858bda000dc60 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 14:24:14 -0500 Subject: only delete prompt on back end and remove client-side deletion --- modules/ui.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 2748a2e0..90c338da 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,15 +429,21 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -def clear_prompt(prompt): - print(f"type: '{prompt}'") - print(prompt) - # update_token_counter(prompt, steps) - return "" +def connect_trash_prompt(_, confirmed): + if(confirmed): + # update_token_counter(prompt, steps) + return ["", confirmed] -def connect_trash_prompt(prompt, confirmed): - if(confirmed): clear_prompt(prompt) +def trash_prompt_click(button, prompt): + dummy_component = gradio.Label(visible=False) + + button.click( + _js="trash_prompt", + fn=connect_trash_prompt, + inputs=[prompt, dummy_component], + outputs=[prompt, dummy_component], + ) def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): @@ -643,13 +649,7 @@ def create_ui(wrap_gradio_gpu_call): txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - trash_prompt_button.click( - # fn=lambda: print("Clearing prompt"), - _js="trash_prompt", - fn=lambda: clear_prompt(), - inputs=[txt2img_prompt, dummy_component], - outputs=[txt2img_prompt, dummy_component], - ) + trash_prompt_click(trash_prompt_button, txt2img_prompt) with gr.Row(elem_id='txt2img_progress_row'): @@ -858,6 +858,7 @@ def create_ui(wrap_gradio_gpu_call): img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) + trash_prompt_click(trash_prompt_button, img2img_prompt) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) -- cgit v1.2.3 From de70ddaf58fae98c561738a54f574abfa14cd8d1 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 15:00:35 -0500 Subject: update token counter when clearing prompt --- modules/ui.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 90c338da..d3a89bf7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -430,19 +430,16 @@ def create_seed_inputs(): -def connect_trash_prompt(_, confirmed): +def connect_trash_prompt(_prompt, confirmed, _token_counter): if(confirmed): - # update_token_counter(prompt, steps) - return ["", confirmed] - -def trash_prompt_click(button, prompt): - dummy_component = gradio.Label(visible=False) + return ["", confirmed, update_token_counter("", 1)] +def trash_prompt_click(button, prompt, _dummy_confirmed, token_counter): button.click( _js="trash_prompt", fn=connect_trash_prompt, - inputs=[prompt, dummy_component], - outputs=[prompt, dummy_component], + inputs=[prompt, _dummy_confirmed, token_counter], + outputs=[prompt, _dummy_confirmed, token_counter], ) @@ -649,7 +646,6 @@ def create_ui(wrap_gradio_gpu_call): txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - trash_prompt_click(trash_prompt_button, txt2img_prompt) with gr.Row(elem_id='txt2img_progress_row'): @@ -720,6 +716,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + trash_prompt_click(trash_prompt_button, txt2img_prompt, dummy_component, token_counter) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), @@ -858,7 +855,6 @@ def create_ui(wrap_gradio_gpu_call): img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) - trash_prompt_click(trash_prompt_button, img2img_prompt) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -958,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + trash_prompt_click(trash_prompt_button, img2img_prompt, dummy_component, token_counter) img2img_prompt_img.change( fn=modules.images.image_data, -- cgit v1.2.3 From 9e40520f00d836cfa93187f7f1e81e2a7bd100b9 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 15:13:12 -0500 Subject: refactor internal terminology to use 'clear' instead of 'trash' like #2728 --- modules/shared.py | 2 +- modules/ui.py | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 1585d532..ab5a0e9a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -317,7 +317,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), - "trash_prompt_visible": OptionInfo(True, "Show trash prompt button"), + "clear_prompt_visible": OptionInfo(True, "Show clear prompt button"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui.py b/modules/ui.py index d3a89bf7..31150800 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -88,7 +88,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 -trash_prompt_symbol = '\U0001F5D1' # +clear_prompt_symbol = '\U0001F5D1' # 🗑️ def plaintext_to_html(text): @@ -430,14 +430,14 @@ def create_seed_inputs(): -def connect_trash_prompt(_prompt, confirmed, _token_counter): +def clear_prompt(_prompt, confirmed, _token_counter): if(confirmed): return ["", confirmed, update_token_counter("", 1)] -def trash_prompt_click(button, prompt, _dummy_confirmed, token_counter): +def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter): button.click( - _js="trash_prompt", - fn=connect_trash_prompt, + _js="clear_prompt", + fn=clear_prompt, inputs=[prompt, _dummy_confirmed, token_counter], outputs=[prompt, _dummy_confirmed, token_counter], ) @@ -518,7 +518,7 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt", visible=opts.trash_prompt_visible) + clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id="clear_prompt", visible=opts.clear_prompt_visible) token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") @@ -559,7 +559,7 @@ def create_toprow(is_img2img): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) prompt_style2.save_to_config = True - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, trash_prompt + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, clear_prompt_button def setup_progressbar(progressbar, preview, id_part, textinfo=None): @@ -640,7 +640,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\ txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\ - token_button, trash_prompt_button = create_toprow(is_img2img=False) + token_button, clear_prompt_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -716,7 +716,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - trash_prompt_click(trash_prompt_button, txt2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, txt2img_prompt, dummy_component, token_counter) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), @@ -853,7 +853,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit,\ img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ - token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) + token_counter, token_button, clear_prompt_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): @@ -954,7 +954,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - trash_prompt_click(trash_prompt_button, img2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, img2img_prompt, dummy_component, token_counter) img2img_prompt_img.change( fn=modules.images.image_data, -- cgit v1.2.3 From 0c7cf08b3d27a61bab4cd8b16f8be8ae74879423 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 15:32:26 -0500 Subject: some doc and formatting --- modules/ui.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 31150800..b26cf10a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -88,7 +88,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 -clear_prompt_symbol = '\U0001F5D1' # 🗑️ +clear_prompt_symbol = '\U0001F5D1' # 🗑️ def plaintext_to_html(text): @@ -429,12 +429,14 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox - def clear_prompt(_prompt, confirmed, _token_counter): - if(confirmed): - return ["", confirmed, update_token_counter("", 1)] + """Given confirmation from a user on the client-side, go ahead with clearing prompt""" + if confirmed: + return ["", confirmed, update_token_counter("", 1)] + def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter): + """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" button.click( _js="clear_prompt", fn=clear_prompt, @@ -518,7 +520,12 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id="clear_prompt", visible=opts.clear_prompt_visible) + + clear_prompt_button = gr.Button( + value=clear_prompt_symbol, + elem_id="clear_prompt", + visible=opts.clear_prompt_visible + ) token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") -- cgit v1.2.3 From 57eb54b838faa383c10079e1bb5471b7bee6a695 Mon Sep 17 00:00:00 2001 From: Extraltodeus Date: Sat, 22 Oct 2022 00:11:07 +0200 Subject: implement CUDA device selection by ID --- modules/devices.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index eb422583..8a159282 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,7 +1,6 @@ +import sys, os, shlex import contextlib - import torch - from modules import errors # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility @@ -9,10 +8,26 @@ has_mps = getattr(torch, 'has_mps', False) cpu = torch.device("cpu") +def extract_device_id(args, name): + for x in range(len(args)): + if name in args[x]: return args[x+1] + return None def get_optimal_device(): if torch.cuda.is_available(): - return torch.device("cuda") + # CUDA device selection support: + if "shared" not in sys.modules: + commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop. + sys.argv += shlex.split(commandline_args) + device_id = extract_device_id(sys.argv, '--device-id') + else: + device_id = shared.cmd_opts.device_id + + if device_id is not None: + cuda_device = f"cuda:{device_id}" + return torch.device(cuda_device) + else: + return torch.device("cuda") if has_mps: return torch.device("mps") -- cgit v1.2.3 From 29bfacd63cb5c73b9643d94f255cca818fd49d9c Mon Sep 17 00:00:00 2001 From: Extraltodeus Date: Sat, 22 Oct 2022 00:12:46 +0200 Subject: implement CUDA device selection, --device-id arg --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 41d7f08e..03032a47 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -80,6 +80,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) cmd_opts = parser.parse_args() restricted_opts = [ -- cgit v1.2.3 From 700340448baa7412c7cc5ff3d1349ac79ee8ed0c Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 17:24:04 -0500 Subject: forgot to clear neg prompt after moving to back. Add tooltip to hints --- modules/ui.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index b26cf10a..25aeba3b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,19 +429,19 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -def clear_prompt(_prompt, confirmed, _token_counter): +def clear_prompt(_prompt, _prompt_neg, confirmed, _token_counter): """Given confirmation from a user on the client-side, go ahead with clearing prompt""" if confirmed: - return ["", confirmed, update_token_counter("", 1)] + return ["", "", confirmed, update_token_counter("", 1)] -def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter): +def connect_clear_prompt(button, prompt, prompt_neg, _dummy_confirmed, token_counter): """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" button.click( _js="clear_prompt", fn=clear_prompt, - inputs=[prompt, _dummy_confirmed, token_counter], - outputs=[prompt, _dummy_confirmed, token_counter], + inputs=[prompt, prompt_neg, _dummy_confirmed, token_counter], + outputs=[prompt, prompt_neg, _dummy_confirmed, token_counter], ) @@ -723,7 +723,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - connect_clear_prompt(clear_prompt_button, txt2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, txt2img_prompt, txt2img_negative_prompt, dummy_component, token_counter) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), @@ -961,7 +961,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - connect_clear_prompt(clear_prompt_button, img2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, img2img_prompt, img2img_negative_prompt, dummy_component, token_counter) img2img_prompt_img.change( fn=modules.images.image_data, -- cgit v1.2.3 From 40ddb6df61564684263c7442bacf61efe3882b87 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sat, 22 Oct 2022 10:16:22 +0800 Subject: inspiration perfected --- modules/inspiration.py | 71 +++++++++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 32 deletions(-) (limited to 'modules') diff --git a/modules/inspiration.py b/modules/inspiration.py index f72ebf3a..319183ab 100644 --- a/modules/inspiration.py +++ b/modules/inspiration.py @@ -13,7 +13,7 @@ def read_name_list(file, types=None, keyword=None): line = line.rstrip("\n") if types is not None: dirname = os.path.split(line) - if dirname[0] in types and keyword in dirname[1]: + if dirname[0] in types and keyword in dirname[1].lower(): ret.append(line) else: ret.append(line) @@ -21,8 +21,10 @@ def read_name_list(file, types=None, keyword=None): return ret def save_name_list(file, name): - with open(file, "a") as f: - f.write(name + "\n") + name_list = read_name_list(file) + if name not in name_list: + with open(file, "a") as f: + f.write(name + "\n") def get_types_list(): files = os.listdir(opts.inspiration_dir) @@ -39,20 +41,20 @@ def get_types_list(): return types def get_inspiration_images(source, types, keyword): + keyword = keyword.strip(" ").lower() get_num = int(opts.inspiration_rows_num * opts.inspiration_cols_num) if source == "Favorites": names = read_name_list(os.path.join(inspiration_system_path, "faverites.txt"), types, keyword) names = random.sample(names, get_num) if len(names) > get_num else names elif source == "Abandoned": names = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) - print(names) names = random.sample(names, get_num) if len(names) > get_num else names elif source == "Exclude abandoned": abandoned = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) all_names = [] for tp in types: name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) - all_names += [os.path.join(tp, x) for x in name_list if keyword in x] + all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()] if len(all_names) > get_num: names = [] @@ -66,14 +68,14 @@ def get_inspiration_images(source, types, keyword): all_names = [] for tp in types: name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) - all_names += [os.path.join(tp, x) for x in name_list if keyword in x] + all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()] names = random.sample(all_names, get_num) if len(all_names) > get_num else all_names image_list = [] for a in names: image_path = os.path.join(opts.inspiration_dir, a) images = os.listdir(image_path) image_list.append((os.path.join(image_path, random.choice(images)), a)) - return image_list, names, "" + return image_list, names def select_click(index, name_list): name = name_list[int(index)] @@ -83,22 +85,18 @@ def select_click(index, name_list): def give_up_click(name): file = os.path.join(inspiration_system_path, "abandoned.txt") - name_list = read_name_list(file) - if name not in name_list: - save_name_list(file, name) + save_name_list(file, name) return "Added to abandoned list" def collect_click(name): file = os.path.join(inspiration_system_path, "faverites.txt") - name_list = read_name_list(file) - if name not in name_list: - save_name_list(file, name) + save_name_list(file, name) return "Added to faverite list" def moveout_click(name, source): if source == "Abandoned": file = os.path.join(inspiration_system_path, "abandoned.txt") - if source == "Favorites": + elif source == "Favorites": file = os.path.join(inspiration_system_path, "faverites.txt") else: return None @@ -107,8 +105,8 @@ def moveout_click(name, source): with open(file, "a") as f: for a in name_list: if a != name: - f.write(a) - return "Moved out {name} from {source} list" + f.write(a + "\n") + return f"Moved out {name} from {source} list" def source_change(source): if source in ["Abandoned", "Favorites"]: @@ -116,10 +114,12 @@ def source_change(source): else: return gradio.update(visible=False), [] def add_to_prompt(name, prompt): - print(name, prompt) name = os.path.basename(name) return prompt + "," + name +def clear_keyword(): + return "" + def ui(gr, opts, txt2img_prompt, img2img_prompt): with gr.Blocks(analytics_enabled=False) as inspiration: flag = os.path.exists(opts.inspiration_dir) @@ -132,15 +132,15 @@ def ui(gr, opts, txt2img_prompt, img2img_prompt): gr.HTML("""

To activate inspiration function, you need get "inspiration" images first.


You can create these images by run "Create inspiration images" script in txt2img page,
you can get the artists or art styles list from here
- https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
+ https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
download these files, and select these files in the "Create inspiration images" script UI
There about 6000 artists and art styles in these files.
This takes server hours depending on your GPU type and how many pictures you generate for each artist/style
I suggest at least four images for each


You can also download generated pictures from here:


- https://huggingface.co/datasets/yfszzx/inspiration
+ https://huggingface.co/datasets/yfszzx/inspiration
unzip the file to the project directory of webui
and restart webui, and enjoy the joy of creation!
- """) + """) return inspiration if not os.path.exists(inspiration_system_path): os.mkdir(inspiration_system_path) @@ -148,35 +148,42 @@ def ui(gr, opts, txt2img_prompt, img2img_prompt): with gr.Column(scale=2): inspiration_gallery = gr.Gallery(show_label=False, elem_id="inspiration_gallery").style(grid=opts.inspiration_cols_num, height='auto') with gr.Column(scale=1): - print(types) types = gr.CheckboxGroup(choices=types, value=types) - keyword = gr.Textbox("", label="Key word") - with gr.Row(): - source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") - get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button") - name = gr.Textbox(show_label=False, interactive=False) with gr.Row(): + source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") + keyword = gr.Textbox("", label="Key word") + get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button") + name = gr.Textbox(show_label=False, interactive=False) + with gr.Row(): send_to_txt2img = gr.Button('to txt2img') send_to_img2img = gr.Button('to img2img') style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') - collect = gr.Button('Collect') - give_up = gr.Button("Don't show again") - moveout = gr.Button("Move out", visible=False) warning = gr.HTML() + with gr.Row(): + collect = gr.Button('Collect') + give_up = gr.Button("Don't show again") + moveout = gr.Button("Move out", visible=False) + with gr.Row(visible=False): select_button = gr.Button('set button', elem_id="inspiration_select_button") name_list = gr.State() - get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list, keyword]) - source.change(source_change, inputs=[source], outputs=[moveout, style_gallery]) - source.change(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list]) keyword.submit(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + source.change(source_change, inputs=[source], outputs=[moveout, style_gallery]) + source.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword]) + types.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword]) + select_button.click(select_click, _js="inspiration_selected", inputs=[name, name_list], outputs=[name, style_gallery, warning]) give_up.click(give_up_click, inputs=[name], outputs=[warning]) collect.click(collect_click, inputs=[name], outputs=[warning]) moveout.click(moveout_click, inputs=[name, source], outputs=[warning]) + moveout.click(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + send_to_txt2img.click(add_to_prompt, inputs=[name, txt2img_prompt], outputs=[txt2img_prompt]) send_to_img2img.click(add_to_prompt, inputs=[name, img2img_prompt], outputs=[img2img_prompt]) + send_to_txt2img.click(collect_click, inputs=[name], outputs=[warning]) + send_to_img2img.click(collect_click, inputs=[name], outputs=[warning]) send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None) send_to_img2img.click(None, _js="switch_to_img2img_img2img", inputs=None, outputs=None) return inspiration -- cgit v1.2.3 From d93ea5cdeb2fd3607b7265271ccab2c9bf4c1156 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sat, 22 Oct 2022 10:21:21 +0800 Subject: inspiration perfected --- modules/inspiration.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/inspiration.py b/modules/inspiration.py index 319183ab..94ff139a 100644 --- a/modules/inspiration.py +++ b/modules/inspiration.py @@ -73,8 +73,11 @@ def get_inspiration_images(source, types, keyword): image_list = [] for a in names: image_path = os.path.join(opts.inspiration_dir, a) - images = os.listdir(image_path) - image_list.append((os.path.join(image_path, random.choice(images)), a)) + images = os.listdir(image_path) + if len(images) > 0: + image_list.append((os.path.join(image_path, random.choice(images)), a)) + else: + print(image_path) return image_list, names def select_click(index, name_list): -- cgit v1.2.3 From 67b78f0ea6f196bfdca49932da062631bb40d0b1 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sat, 22 Oct 2022 10:29:23 +0800 Subject: inspiration perfected --- modules/inspiration.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/inspiration.py b/modules/inspiration.py index 94ff139a..29cf8297 100644 --- a/modules/inspiration.py +++ b/modules/inspiration.py @@ -160,12 +160,13 @@ def ui(gr, opts, txt2img_prompt, img2img_prompt): with gr.Row(): send_to_txt2img = gr.Button('to txt2img') send_to_img2img = gr.Button('to img2img') - style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') - warning = gr.HTML() - with gr.Row(): collect = gr.Button('Collect') give_up = gr.Button("Don't show again") - moveout = gr.Button("Move out", visible=False) + moveout = gr.Button("Move out", visible=False) + warning = gr.HTML() + style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') + + with gr.Row(visible=False): select_button = gr.Button('set button', elem_id="inspiration_select_button") -- cgit v1.2.3 From 2b91251637078e04472c91a06a8d9c4db9c1dcf0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 12:23:45 +0300 Subject: removed aesthetic gradients as built-in added support for extensions --- modules/aesthetic_clip.py | 241 -------------------------------------------- modules/images_history.py | 2 +- modules/img2img.py | 5 +- modules/processing.py | 35 ++++--- modules/script_callbacks.py | 42 ++++++++ modules/scripts.py | 210 ++++++++++++++++++++++++++++---------- modules/sd_hijack.py | 1 - modules/sd_models.py | 7 +- modules/shared.py | 19 ---- modules/txt2img.py | 5 +- modules/ui.py | 83 +++------------ 11 files changed, 244 insertions(+), 406 deletions(-) delete mode 100644 modules/aesthetic_clip.py create mode 100644 modules/script_callbacks.py (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py deleted file mode 100644 index 8c828541..00000000 --- a/modules/aesthetic_clip.py +++ /dev/null @@ -1,241 +0,0 @@ -import copy -import itertools -import os -from pathlib import Path -import html -import gc - -import gradio as gr -import torch -from PIL import Image -from torch import optim - -from modules import shared -from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer -from tqdm.auto import tqdm, trange -from modules.shared import opts, device - - -def get_all_images_in_folder(folder): - return [os.path.join(folder, f) for f in os.listdir(folder) if - os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)] - - -def check_is_valid_image_file(filename): - return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp")) - - -def batched(dataset, total, n=1): - for ndx in range(0, total, n): - yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] - - -def iter_to_batched(iterable, n=1): - it = iter(iterable) - while True: - chunk = tuple(itertools.islice(it, n)) - if not chunk: - return - yield chunk - - -def create_ui(): - import modules.ui - - with gr.Group(): - with gr.Accordion("Open for Clip Aesthetic!", open=False): - with gr.Row(): - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", - value=0.9) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) - - with gr.Row(): - aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', - placeholder="Aesthetic learning rate", value="0.0001") - aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) - aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()), - label="Aesthetic imgs embedding", - value="None") - - modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings") - - with gr.Row(): - aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', - placeholder="This text is used to rotate the feature space of the imgs embs", - value="") - aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01, - value=0.1) - aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) - - return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative - - -aesthetic_clip_model = None - - -def aesthetic_clip(): - global aesthetic_clip_model - - if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path: - aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path) - aesthetic_clip_model.cpu() - - return aesthetic_clip_model - - -def generate_imgs_embd(name, folder, batch_size): - model = aesthetic_clip().to(device) - processor = CLIPProcessor.from_pretrained(model.name_or_path) - - with torch.no_grad(): - embs = [] - for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size), - desc=f"Generating embeddings for {name}"): - if shared.state.interrupted: - break - inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device) - outputs = model.get_image_features(**inputs).cpu() - embs.append(torch.clone(outputs)) - inputs.to("cpu") - del inputs, outputs - - embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True) - - # The generated embedding will be located here - path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") - torch.save(embs, path) - - model.cpu() - del processor - del embs - gc.collect() - torch.cuda.empty_cache() - res = f""" - Done generating embedding for {name}! - Aesthetic embedding saved to {html.escape(path)} - """ - shared.update_aesthetic_embeddings() - return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", - value="None"), \ - gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), - label="Imgs embedding", - value="None"), res, "" - - -def slerp(low, high, val): - low_norm = low / torch.norm(low, dim=1, keepdim=True) - high_norm = high / torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm * high_norm).sum(1)) - so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high - return res - - -class AestheticCLIP: - def __init__(self): - self.skip = False - self.aesthetic_steps = 0 - self.aesthetic_weight = 0 - self.aesthetic_lr = 0 - self.slerp = False - self.aesthetic_text_negative = "" - self.aesthetic_slerp_angle = 0 - self.aesthetic_imgs_text = "" - - self.image_embs_name = None - self.image_embs = None - self.load_image_embs(None) - - def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, - aesthetic_slerp=True, aesthetic_imgs_text="", - aesthetic_slerp_angle=0.15, - aesthetic_text_negative=False): - self.aesthetic_imgs_text = aesthetic_imgs_text - self.aesthetic_slerp_angle = aesthetic_slerp_angle - self.aesthetic_text_negative = aesthetic_text_negative - self.slerp = aesthetic_slerp - self.aesthetic_lr = aesthetic_lr - self.aesthetic_weight = aesthetic_weight - self.aesthetic_steps = aesthetic_steps - self.load_image_embs(image_embs_name) - - if self.image_embs_name is not None: - p.extra_generation_params.update({ - "Aesthetic LR": aesthetic_lr, - "Aesthetic weight": aesthetic_weight, - "Aesthetic steps": aesthetic_steps, - "Aesthetic embedding": self.image_embs_name, - "Aesthetic slerp": aesthetic_slerp, - "Aesthetic text": aesthetic_imgs_text, - "Aesthetic text negative": aesthetic_text_negative, - "Aesthetic slerp angle": aesthetic_slerp_angle, - }) - - def set_skip(self, skip): - self.skip = skip - - def load_image_embs(self, image_embs_name): - if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": - image_embs_name = None - self.image_embs_name = None - if image_embs_name is not None and self.image_embs_name != image_embs_name: - self.image_embs_name = image_embs_name - self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) - self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) - self.image_embs.requires_grad_(False) - - def __call__(self, z, remade_batch_tokens): - if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None: - tokenizer = shared.sd_model.cond_stage_model.tokenizer - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [ - [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in - remade_batch_tokens] - - tokens = torch.asarray(remade_batch_tokens).to(device) - - model = copy.deepcopy(aesthetic_clip()).to(device) - model.requires_grad_(True) - if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: - text_embs_2 = model.get_text_features( - **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) - if self.aesthetic_text_negative: - text_embs_2 = self.image_embs - text_embs_2 - text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) - img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) - else: - img_embs = self.image_embs - - with torch.enable_grad(): - - # We optimize the model to maximize the similarity - optimizer = optim.Adam( - model.text_model.parameters(), lr=self.aesthetic_lr - ) - - for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"): - text_embs = model.get_text_features(input_ids=tokens) - text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) - sim = text_embs @ img_embs.T - loss = -sim - optimizer.zero_grad() - loss.mean().backward() - optimizer.step() - - zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) - if opts.CLIP_stop_at_last_layers > 1: - zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] - zn = model.text_model.final_layer_norm(zn) - else: - zn = zn.last_hidden_state - model.cpu() - del model - gc.collect() - torch.cuda.empty_cache() - zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1) - if self.slerp: - z = slerp(z, zn, self.aesthetic_weight) - else: - z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight - - return z diff --git a/modules/images_history.py b/modules/images_history.py index 78fd0543..bc5cf11f 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -310,7 +310,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): forward = gr.Button('Prev batch') backward = gr.Button('Next batch') with gr.Column(scale=3): - load_info = gr.HTML(visible=not custom_dir) + load_info = gr.HTML(visible=not custom_dir) with gr.Row(visible=False) as warning: warning_box = gr.Textbox("Message", interactive=False) diff --git a/modules/img2img.py b/modules/img2img.py index eea5199b..8d9f7cf9 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args): +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_inpaint = mode == 1 is_batch = mode == 2 @@ -109,7 +109,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpainting_mask_invert=inpainting_mask_invert, ) - shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) + p.scripts = modules.scripts.scripts_txt2img + p.script_args = args if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index ff1ec4c9..372489f7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -104,6 +104,12 @@ class StableDiffusionProcessing(): self.seed_resize_from_h = 0 self.seed_resize_from_w = 0 + self.scripts = None + self.script_args = None + self.all_prompts = None + self.all_seeds = None + self.all_subseeds = None + def init(self, all_prompts, all_seeds, all_subseeds): pass @@ -350,32 +356,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed: shared.prompt_styles.apply_styles(p) if type(p.prompt) == list: - all_prompts = p.prompt + p.all_prompts = p.prompt else: - all_prompts = p.batch_size * p.n_iter * [p.prompt] + p.all_prompts = p.batch_size * p.n_iter * [p.prompt] if type(seed) == list: - all_seeds = seed + p.all_seeds = seed else: - all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] + p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))] if type(subseed) == list: - all_subseeds = subseed + p.all_subseeds = subseed else: - all_subseeds = [int(subseed) + x for x in range(len(all_prompts))] + p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] def infotext(iteration=0, position_in_batch=0): - return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) + return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() + if p.scripts is not None: + p.scripts.run_alwayson_scripts(p) + infotexts = [] output_images = [] with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): - p.init(all_prompts, all_seeds, all_subseeds) + p.init(p.all_prompts, p.all_seeds, p.all_subseeds) if state.job_count == -1: state.job_count = p.n_iter @@ -387,9 +396,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.interrupted: break - prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size] - seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size] - subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] + prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] + seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] + subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] if (len(prompts) == 0): break @@ -490,10 +499,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: index_of_first_image = 1 if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) devices.torch_gc() - return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) + return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py new file mode 100644 index 00000000..866b7acd --- /dev/null +++ b/modules/script_callbacks.py @@ -0,0 +1,42 @@ + +callbacks_model_loaded = [] +callbacks_ui_tabs = [] + + +def clear_callbacks(): + callbacks_model_loaded.clear() + callbacks_ui_tabs.clear() + + +def model_loaded_callback(sd_model): + for callback in callbacks_model_loaded: + callback(sd_model) + + +def ui_tabs_callback(): + res = [] + + for callback in callbacks_ui_tabs: + res += callback() or [] + + return res + + +def on_model_loaded(callback): + """register a function to be called when the stable diffusion model is created; the model is + passed as an argument""" + callbacks_model_loaded.append(callback) + + +def on_ui_tabs(callback): + """register a function to be called when the UI is creating new tabs. + The function must either return a None, which means no new tabs to be added, or a list, where + each element is a tuple: + (gradio_component, title, elem_id) + + gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks) + title is tab text displayed to user in the UI + elem_id is HTML id for the tab + """ + callbacks_ui_tabs.append(callback) + diff --git a/modules/scripts.py b/modules/scripts.py index 1039fa9c..65f25f49 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,86 +1,153 @@ import os import sys import traceback +from collections import namedtuple import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared +from modules import shared, paths, script_callbacks + +AlwaysVisible = object() + class Script: filename = None args_from = None args_to = None + alwayson = False + + infotext_fields = None + """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when + parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example + """ - # The title of the script. This is what will be displayed in the dropdown menu. def title(self): + """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" + raise NotImplementedError() - # How the script is displayed in the UI. See https://gradio.app/docs/#components - # for the different UI components you can use and how to create them. - # Most UI components can return a value, such as a boolean for a checkbox. - # The returned values are passed to the run method as parameters. def ui(self, is_img2img): + """this function should create gradio UI elements. See https://gradio.app/docs/#components + The return value should be an array of all components that are used in processing. + Values of those returned componenbts will be passed to run() and process() functions. + """ + pass - # Determines when the script should be shown in the dropdown menu via the - # returned value. As an example: - # is_img2img is True if the current tab is img2img, and False if it is txt2img. - # Thus, return is_img2img to only show the script on the img2img tab. def show(self, is_img2img): + """ + is_img2img is True if this function is called for the img2img interface, and Fasle otherwise + + This function should return: + - False if the script should not be shown in UI at all + - True if the script should be shown in UI if it's scelected in the scripts drowpdown + - script.AlwaysVisible if the script should be shown in UI at all times + """ + return True - # This is where the additional processing is implemented. The parameters include - # self, the model object "p" (a StableDiffusionProcessing class, see - # processing.py), and the parameters returned by the ui method. - # Custom functions can be defined here, and additional libraries can be imported - # to be used in processing. The return value should be a Processed object, which is - # what is returned by the process_images method. - def run(self, *args): + def run(self, p, *args): + """ + This function is called if the script has been selected in the script dropdown. + It must do all processing and return the Processed object with results, same as + one returned by processing.process_images. + + Usually the processing is done by calling the processing.process_images function. + + args contains all values returned by components from ui() + """ + raise NotImplementedError() - # The description method is currently unused. - # To add a description that appears when hovering over the title, amend the "titles" - # dict in script.js to include the script title (returned by title) as a key, and - # your description as the value. + def process(self, p, *args): + """ + This function is called before processing begins for AlwaysVisible scripts. + scripts. You can modify the processing object (p) here, inject hooks, etc. + """ + + pass + def describe(self): + """unused""" return "" +current_basedir = paths.script_path + + +def basedir(): + """returns the base directory for the current script. For scripts in the main scripts directory, + this is the main directory (where webui.py resides), and for scripts in extensions directory + (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic) + """ + return current_basedir + + scripts_data = [] +ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) +ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"]) + + +def list_scripts(scriptdirname, extension): + scripts_list = [] + + basedir = os.path.join(paths.script_path, scriptdirname) + if os.path.exists(basedir): + for filename in sorted(os.listdir(basedir)): + scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) + + extdir = os.path.join(paths.script_path, "extensions") + if os.path.exists(extdir): + for dirname in sorted(os.listdir(extdir)): + dirpath = os.path.join(extdir, dirname) + if not os.path.isdir(dirpath): + continue + for filename in sorted(os.listdir(os.path.join(dirpath, scriptdirname))): + scripts_list.append(ScriptFile(dirpath, filename, os.path.join(dirpath, scriptdirname, filename))) -def load_scripts(basedir): - if not os.path.exists(basedir): - return + scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] - for filename in sorted(os.listdir(basedir)): - path = os.path.join(basedir, filename) + return scripts_list - if os.path.splitext(path)[1].lower() != '.py': - continue - if not os.path.isfile(path): - continue +def load_scripts(): + global current_basedir + scripts_data.clear() + script_callbacks.clear_callbacks() + + scripts_list = list_scripts("scripts", ".py") + + syspath = sys.path + for scriptfile in sorted(scripts_list): try: - with open(path, "r", encoding="utf8") as file: + if scriptfile.basedir != paths.script_path: + sys.path = [scriptfile.basedir] + sys.path + current_basedir = scriptfile.basedir + + with open(scriptfile.path, "r", encoding="utf8") as file: text = file.read() from types import ModuleType - compiled = compile(text, path, 'exec') - module = ModuleType(filename) + compiled = compile(text, scriptfile.path, 'exec') + module = ModuleType(scriptfile.filename) exec(compiled, module.__dict__) for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): - scripts_data.append((script_class, path)) + scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir)) except Exception: - print(f"Error loading script: {filename}", file=sys.stderr) + print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + finally: + sys.path = syspath + current_basedir = paths.script_path + def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: @@ -96,56 +163,80 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs): class ScriptRunner: def __init__(self): self.scripts = [] + self.selectable_scripts = [] + self.alwayson_scripts = [] self.titles = [] + self.infotext_fields = [] def setup_ui(self, is_img2img): - for script_class, path in scripts_data: + for script_class, path, basedir in scripts_data: script = script_class() script.filename = path - if not script.show(is_img2img): - continue + visibility = script.show(is_img2img) - self.scripts.append(script) + if visibility == AlwaysVisible: + self.scripts.append(script) + self.alwayson_scripts.append(script) + script.alwayson = True - self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] + elif visibility: + self.scripts.append(script) + self.selectable_scripts.append(script) - dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") - dropdown.save_to_config = True - inputs = [dropdown] + self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] + + inputs = [None] + inputs_alwayson = [True] - for script in self.scripts: + def create_script_ui(script, inputs, inputs_alwayson): script.args_from = len(inputs) script.args_to = len(inputs) controls = wrap_call(script.ui, script.filename, "ui", is_img2img) if controls is None: - continue + return for control in controls: control.custom_script_source = os.path.basename(script.filename) - control.visible = False + if not script.alwayson: + control.visible = False + + if script.infotext_fields is not None: + self.infotext_fields += script.infotext_fields inputs += controls + inputs_alwayson += [script.alwayson for _ in controls] script.args_to = len(inputs) + for script in self.alwayson_scripts: + with gr.Group(): + create_script_ui(script, inputs, inputs_alwayson) + + dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") + dropdown.save_to_config = True + inputs[0] = dropdown + + for script in self.selectable_scripts: + create_script_ui(script, inputs, inputs_alwayson) + def select_script(script_index): - if 0 < script_index <= len(self.scripts): - script = self.scripts[script_index-1] + if 0 < script_index <= len(self.selectable_scripts): + script = self.selectable_scripts[script_index-1] args_from = script.args_from args_to = script.args_to else: args_from = 0 args_to = 0 - return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] + return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)] def init_field(title): if title == 'None': return script_index = self.titles.index(title) - script = self.scripts[script_index] + script = self.selectable_scripts[script_index] for i in range(script.args_from, script.args_to): inputs[i].visible = True @@ -164,7 +255,7 @@ class ScriptRunner: if script_index == 0: return None - script = self.scripts[script_index-1] + script = self.selectable_scripts[script_index-1] if script is None: return None @@ -176,6 +267,15 @@ class ScriptRunner: return processed + def run_alwayson_scripts(self, p): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.process(p, *script_args) + except Exception: + print(f"Error running alwayson script: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def reload_sources(self): for si, script in list(enumerate(self.scripts)): with open(script.filename, "r", encoding="utf8") as file: @@ -197,19 +297,21 @@ class ScriptRunner: self.scripts[si].args_from = args_from self.scripts[si].args_to = args_to + scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + def reload_script_body_only(): scripts_txt2img.reload_sources() scripts_img2img.reload_sources() -def reload_scripts(basedir): +def reload_scripts(): global scripts_txt2img, scripts_img2img - scripts_data.clear() - load_scripts(basedir) + load_scripts() scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 1f8587d1..0f10828e 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -332,7 +332,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers.append([1.0] * 75) z1 = self.process_tokens(tokens, multipliers) - z1 = shared.aesthetic_clip(z1, remade_batch_tokens) z = z1 if z is None else torch.cat((z, z1), axis=-2) remade_batch_tokens = rem_tokens diff --git a/modules/sd_models.py b/modules/sd_models.py index d99dbce8..f9b3063d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -7,7 +7,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices +from modules import shared, modelloader, devices, script_callbacks from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -238,6 +238,9 @@ def load_model(checkpoint_info=None): sd_hijack.model_hijack.hijack(sd_model) sd_model.eval() + shared.sd_model = sd_model + + script_callbacks.model_loaded_callback(sd_model) print(f"Model loaded.") return sd_model @@ -252,7 +255,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - shared.sd_model = load_model(checkpoint_info) + load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: diff --git a/modules/shared.py b/modules/shared.py index 0dbe360d..7d786f07 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -31,7 +31,6 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") -parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(models_path, 'aesthetic_embeddings'), help="aesthetic_embeddings directory(default: aesthetic_embeddings)") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") @@ -109,21 +108,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None - -os.makedirs(cmd_opts.aesthetic_embeddings_dir, exist_ok=True) -aesthetic_embeddings = {} - - -def update_aesthetic_embeddings(): - global aesthetic_embeddings - aesthetic_embeddings = {f.replace(".pt", ""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in - os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} - aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings) - - -update_aesthetic_embeddings() - - def reload_hypernetworks(): global hypernetworks @@ -415,9 +399,6 @@ sd_model = None clip_model = None -from modules.aesthetic_clip import AestheticCLIP -aesthetic_clip = AestheticCLIP() - progress_print_out = sys.stdout diff --git a/modules/txt2img.py b/modules/txt2img.py index 1761cfa2..c9d5a090 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -7,7 +7,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -36,7 +36,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: firstphase_height=firstphase_height if enable_hr else None, ) - shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) + p.scripts = modules.scripts.scripts_txt2img + p.script_args = args if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/ui.py b/modules/ui.py index 70a9cf10..c977482c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,10 +23,10 @@ import gradio as gr import gradio.utils import gradio.routes -from modules import sd_hijack, sd_models, localization +from modules import sd_hijack, sd_models, localization, script_callbacks from modules.paths import script_path -from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings +from modules.shared import opts, cmd_opts, restricted_opts if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags @@ -44,7 +44,6 @@ from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.aesthetic_clip as aesthetic_clip import modules.images_history as img_his @@ -662,8 +661,6 @@ def create_ui(wrap_gradio_gpu_call): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() - aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui() - with gr.Group(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) @@ -718,14 +715,6 @@ def create_ui(wrap_gradio_gpu_call): denoising_strength, firstphase_width, firstphase_height, - aesthetic_lr, - aesthetic_weight, - aesthetic_steps, - aesthetic_imgs, - aesthetic_slerp, - aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative ] + custom_inputs, outputs=[ @@ -804,14 +793,7 @@ def create_ui(wrap_gradio_gpu_call): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (firstphase_width, "First pass size-1"), (firstphase_height, "First pass size-2"), - (aesthetic_lr, "Aesthetic LR"), - (aesthetic_weight, "Aesthetic weight"), - (aesthetic_steps, "Aesthetic steps"), - (aesthetic_imgs, "Aesthetic embedding"), - (aesthetic_slerp, "Aesthetic slerp"), - (aesthetic_imgs_text, "Aesthetic text"), - (aesthetic_text_negative, "Aesthetic text negative"), - (aesthetic_slerp_angle, "Aesthetic slerp angle"), + *modules.scripts.scripts_txt2img.infotext_fields ] txt2img_preview_params = [ @@ -896,8 +878,6 @@ def create_ui(wrap_gradio_gpu_call): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() - aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui() - with gr.Group(): custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) @@ -988,14 +968,6 @@ def create_ui(wrap_gradio_gpu_call): inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, - aesthetic_lr_im, - aesthetic_weight_im, - aesthetic_steps_im, - aesthetic_imgs_im, - aesthetic_slerp_im, - aesthetic_imgs_text_im, - aesthetic_slerp_angle_im, - aesthetic_text_negative_im, ] + custom_inputs, outputs=[ img2img_gallery, @@ -1087,14 +1059,7 @@ def create_ui(wrap_gradio_gpu_call): (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), (denoising_strength, "Denoising strength"), - (aesthetic_lr_im, "Aesthetic LR"), - (aesthetic_weight_im, "Aesthetic weight"), - (aesthetic_steps_im, "Aesthetic steps"), - (aesthetic_imgs_im, "Aesthetic embedding"), - (aesthetic_slerp_im, "Aesthetic slerp"), - (aesthetic_imgs_text_im, "Aesthetic text"), - (aesthetic_text_negative_im, "Aesthetic text negative"), - (aesthetic_slerp_angle_im, "Aesthetic slerp angle"), + *modules.scripts.scripts_img2img.infotext_fields ] token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) @@ -1217,9 +1182,9 @@ def create_ui(wrap_gradio_gpu_call): ) #images history images_history_switch_dict = { - "fn":modules.generation_parameters_copypaste.connect_paste, - "t2i":txt2img_paste_fields, - "i2i":img2img_paste_fields + "fn": modules.generation_parameters_copypaste.connect_paste, + "t2i": txt2img_paste_fields, + "i2i": img2img_paste_fields } images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) @@ -1264,18 +1229,6 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create embedding", variant='primary') - with gr.Tab(label="Create aesthetic images embedding"): - - new_embedding_name_ae = gr.Textbox(label="Name") - process_src_ae = gr.Textbox(label='Source directory') - batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256) - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding_ae = gr.Button(value="Create images embedding", variant='primary') - with gr.Tab(label="Create hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) @@ -1375,21 +1328,6 @@ def create_ui(wrap_gradio_gpu_call): ] ) - create_embedding_ae.click( - fn=aesthetic_clip.generate_imgs_embd, - inputs=[ - new_embedding_name_ae, - process_src_ae, - batch_ae - ], - outputs=[ - aesthetic_imgs, - aesthetic_imgs_im, - ti_output, - ti_outcome, - ] - ) - create_hypernetwork.click( fn=modules.hypernetworks.ui.create_hypernetwork, inputs=[ @@ -1580,10 +1518,10 @@ Requested path was: {f} if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() + oldval = opts.data.get(key, None) if cmd_opts.hide_ui_dir_config and key in restricted_opts: return gr.update(value=oldval), opts.dumpjson() - oldval = opts.data.get(key, None) opts.data[key] = value if oldval != value: @@ -1692,9 +1630,12 @@ Requested path was: {f} (images_history, "Image Browser", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), - (settings_interface, "Settings", "settings"), ] + interfaces += script_callbacks.ui_tabs_callback() + + interfaces += [(settings_interface, "Settings", "settings")] + with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file: css = file.read() -- cgit v1.2.3 From 6398dc9b1049f242576ca309f95a3fb1e654951c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 13:34:49 +0300 Subject: further support for extensions --- modules/scripts.py | 44 +++++++++++++++++++++++++++++++++++--------- modules/ui.py | 19 ++++++++++--------- 2 files changed, 45 insertions(+), 18 deletions(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 65f25f49..9323af3e 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -102,17 +102,39 @@ def list_scripts(scriptdirname, extension): if os.path.exists(extdir): for dirname in sorted(os.listdir(extdir)): dirpath = os.path.join(extdir, dirname) - if not os.path.isdir(dirpath): + scriptdirpath = os.path.join(dirpath, scriptdirname) + + if not os.path.isdir(scriptdirpath): continue - for filename in sorted(os.listdir(os.path.join(dirpath, scriptdirname))): - scripts_list.append(ScriptFile(dirpath, filename, os.path.join(dirpath, scriptdirname, filename))) + for filename in sorted(os.listdir(scriptdirpath)): + scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename))) scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] return scripts_list +def list_files_with_name(filename): + res = [] + + dirs = [paths.script_path] + + extdir = os.path.join(paths.script_path, "extensions") + if os.path.exists(extdir): + dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))] + + for dirpath in dirs: + if not os.path.isdir(dirpath): + continue + + path = os.path.join(dirpath, filename) + if os.path.isfile(filename): + res.append(path) + + return res + + def load_scripts(): global current_basedir scripts_data.clear() @@ -276,7 +298,7 @@ class ScriptRunner: print(f"Error running alwayson script: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - def reload_sources(self): + def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): with open(script.filename, "r", encoding="utf8") as file: args_from = script.args_from @@ -286,9 +308,12 @@ class ScriptRunner: from types import ModuleType - compiled = compile(text, filename, 'exec') - module = ModuleType(script.filename) - exec(compiled, module.__dict__) + module = cache.get(filename, None) + if module is None: + compiled = compile(text, filename, 'exec') + module = ModuleType(script.filename) + exec(compiled, module.__dict__) + cache[filename] = module for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): @@ -303,8 +328,9 @@ scripts_img2img = ScriptRunner() def reload_script_body_only(): - scripts_txt2img.reload_sources() - scripts_img2img.reload_sources() + cache = {} + scripts_txt2img.reload_sources(cache) + scripts_img2img.reload_sources(cache) def reload_scripts(): diff --git a/modules/ui.py b/modules/ui.py index c977482c..29986124 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1636,13 +1636,15 @@ Requested path was: {f} interfaces += [(settings_interface, "Settings", "settings")] - with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file: - css = file.read() + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + with open(cssfile, "r", encoding="utf8") as file: + css += file.read() + "\n" if os.path.exists(os.path.join(script_path, "user.css")): with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - usercss = file.read() - css += usercss + css += file.read() + "\n" if not cmd_opts.no_progressbar_hiding: css += css_hide_progressbar @@ -1865,9 +1867,9 @@ def load_javascript(raw_response): with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: javascript = f'' - jsdir = os.path.join(script_path, "javascript") - for filename in sorted(os.listdir(jsdir)): - with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: + scripts_list = modules.scripts.list_scripts("javascript", ".js") + for basedir, filename, path in scripts_list: + with open(path, "r", encoding="utf8") as jsfile: javascript += f"\n" if cmd_opts.theme is not None: @@ -1885,6 +1887,5 @@ def load_javascript(raw_response): gradio.routes.templates.TemplateResponse = template_response -reload_javascript = partial(load_javascript, - gradio.routes.templates.TemplateResponse) +reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse) reload_javascript() -- cgit v1.2.3 From 50b5504401e50b6c94eba41b37fe212b2f27b792 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 14:04:14 +0300 Subject: remove parsing command line from devices.py --- modules/devices.py | 14 +++++--------- modules/lowvram.py | 9 ++++----- 2 files changed, 9 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 8a159282..dc1f3cdd 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -15,14 +15,10 @@ def extract_device_id(args, name): def get_optimal_device(): if torch.cuda.is_available(): - # CUDA device selection support: - if "shared" not in sys.modules: - commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop. - sys.argv += shlex.split(commandline_args) - device_id = extract_device_id(sys.argv, '--device-id') - else: - device_id = shared.cmd_opts.device_id - + from modules import shared + + device_id = shared.cmd_opts.device_id + if device_id is not None: cuda_device = f"cuda:{device_id}" return torch.device(cuda_device) @@ -49,7 +45,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") -device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() +device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 diff --git a/modules/lowvram.py b/modules/lowvram.py index 7eba1349..f327c3df 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -1,9 +1,8 @@ import torch -from modules.devices import get_optimal_device +from modules import devices module_in_gpu = None cpu = torch.device("cpu") -device = gpu = get_optimal_device() def send_everything_to_cpu(): @@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram): if module_in_gpu is not None: module_in_gpu.to(cpu) - module.to(gpu) + module.to(devices.device) module_in_gpu = module # see below for register_forward_pre_hook; @@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram): # send the model to GPU. Then put modules back. the modules will be in CPU. stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None - sd_model.to(device) + sd_model.to(devices.device) sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored # register hooks for those the first two models @@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram): # so that only one of them is in GPU at a time stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None - sd_model.model.to(device) + sd_model.model.to(devices.device) diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored # install hooks for bits of third model -- cgit v1.2.3 From 0e8ca8e7af05be22d7d2c07a47c3c7febe0f0ab6 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Sat, 22 Oct 2022 11:07:00 +0000 Subject: add dropout --- modules/hypernetworks/hypernetwork.py | 68 +++++++++++++++++++++-------------- modules/hypernetworks/ui.py | 10 +++--- modules/ui.py | 43 +++++++++++----------- 3 files changed, 70 insertions(+), 51 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 905cbeef..e493f366 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -1,47 +1,60 @@ +import csv import datetime import glob import html import os import sys import traceback -import tqdm -import csv +import modules.textual_inversion.dataset import torch - -from ldm.util import default -from modules import devices, shared, processing, sd_models -import torch -from torch import einsum +import tqdm from einops import rearrange, repeat -import modules.textual_inversion.dataset +from ldm.util import default +from modules import devices, processing, sd_models, shared from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler +from torch import einsum class HypernetworkModule(torch.nn.Module): multiplier = 1.0 - activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU, - "swish": torch.nn.Hardswish} - - def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None): + activation_dict = { + "relu": torch.nn.ReLU, + "leakyrelu": torch.nn.LeakyReLU, + "elu": torch.nn.ELU, + "swish": torch.nn.Hardswish, + } + + def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): super().__init__() assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - + assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'" + linears = [] for i in range(len(layer_structure) - 1): + + # Add a fully-connected layer linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) - # if skip_first_layer because first parameters potentially contain negative values - # if i < 1: continue - if activation_func in HypernetworkModule.activation_dict: - linears.append(HypernetworkModule.activation_dict[activation_func]()) + + # Add an activation func + if activation_func == "linear": + pass + elif activation_func in self.activation_dict: + linears.append(self.activation_dict[activation_func]()) else: - print("Invalid key {} encountered as activation function!".format(activation_func)) - # if use_dropout: - # linears.append(torch.nn.Dropout(p=0.3)) + raise NotImplementedError( + "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'" + ) + + # Add dropout + if use_dropout: + linears.append(torch.nn.Dropout(p=0.3)) + + # Add layer normalization if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -93,7 +106,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): self.filename = None self.name = name self.layers = {} @@ -101,13 +114,14 @@ class Hypernetwork: self.sd_checkpoint = None self.sd_checkpoint_name = None self.layer_structure = layer_structure - self.add_layer_norm = add_layer_norm self.activation_func = activation_func + self.add_layer_norm = add_layer_norm + self.use_dropout = use_dropout for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), - HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), ) def weights(self): @@ -129,8 +143,9 @@ class Hypernetwork: state_dict['step'] = self.step state_dict['name'] = self.name state_dict['layer_structure'] = self.layer_structure - state_dict['is_layer_norm'] = self.add_layer_norm state_dict['activation_func'] = self.activation_func + state_dict['is_layer_norm'] = self.add_layer_norm + state_dict['use_dropout'] = self.use_dropout state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name @@ -144,8 +159,9 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu') self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) - self.add_layer_norm = state_dict.get('is_layer_norm', False) self.activation_func = state_dict.get('activation_func', None) + self.add_layer_norm = state_dict.get('is_layer_norm', False) + self.use_dropout = state_dict.get('use_dropout', False) for size, sd in state_dict.items(): if type(size) == int: diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 1a5a27d8..5f6f17b6 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -3,14 +3,13 @@ import os import re import gradio as gr - -import modules.textual_inversion.textual_inversion import modules.textual_inversion.preprocess -from modules import sd_hijack, shared, devices +import modules.textual_inversion.textual_inversion +from modules import devices, sd_hijack, shared from modules.hypernetworks import hypernetwork -def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None): +def create_hypernetwork(name, enable_sizes, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" @@ -21,8 +20,9 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm name=name, enable_sizes=[int(x) for x in enable_sizes], layer_structure=layer_structure, - add_layer_norm=add_layer_norm, activation_func=activation_func, + add_layer_norm=add_layer_norm, + use_dropout=use_dropout, ) hypernet.save(fn) diff --git a/modules/ui.py b/modules/ui.py index 716f14b8..d4b32c05 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -5,43 +5,44 @@ import json import math import mimetypes import os +import platform import random +import subprocess as sp import sys import tempfile import time import traceback -import platform -import subprocess as sp from functools import partial, reduce +import gradio as gr +import gradio.routes +import gradio.utils import numpy as np +import piexif import torch from PIL import Image, PngImagePlugin -import piexif -import gradio as gr -import gradio.utils -import gradio.routes - -from modules import sd_hijack, sd_models, localization +from modules import localization, sd_hijack, sd_models from modules.paths import script_path -from modules.shared import opts, cmd_opts, restricted_opts +from modules.shared import cmd_opts, opts, restricted_opts + if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags -import modules.shared as shared -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.sd_hijack import model_hijack + +import modules.codeformer_model +import modules.generation_parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.images_history as img_his import modules.ldsr_model import modules.scripts -import modules.gfpgan_model -import modules.codeformer_model +import modules.shared as shared import modules.styles -import modules.generation_parameters_copypaste +import modules.textual_inversion.ui from modules import prompt_parser from modules.images import save_image -import modules.textual_inversion.ui -import modules.hypernetworks.ui -import modules.images_history as img_his +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -1223,8 +1224,9 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") + new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"]) new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") - new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"]) + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") with gr.Row(): with gr.Column(scale=3): @@ -1308,8 +1310,9 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_name, new_hypernetwork_sizes, new_hypernetwork_layer_structure, - new_hypernetwork_add_layer_norm, new_hypernetwork_activation_func, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout ], outputs=[ train_hypernetwork_name, -- cgit v1.2.3 From 1cd3ed7def40198f46d30f74dd37d2906ebdbaa6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 14:28:56 +0300 Subject: fix for extensions without style.css --- modules/ui.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 29986124..d8d52db1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1639,6 +1639,9 @@ Requested path was: {f} css = "" for cssfile in modules.scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + with open(cssfile, "r", encoding="utf8") as file: css += file.read() + "\n" -- cgit v1.2.3 From 7fd90128eb6d1820045bfe2c2c1269661023a712 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 14:48:43 +0300 Subject: added a guard for hypernet training that will stop early if weights are getting no gradients --- modules/hypernetworks/hypernetwork.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 47d91ea5..46039a49 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -310,6 +310,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) + steps_without_grad = 0 + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: hypernetwork.step = i + ititial_step @@ -332,8 +334,17 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log losses[hypernetwork.step % losses.shape[0]] = loss.item() optimizer.zero_grad() + weights[0].grad = None loss.backward() + + if weights[0].grad is None: + steps_without_grad += 1 + else: + steps_without_grad = 0 + assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' + optimizer.step() + mean_loss = losses.mean() if torch.isnan(mean_loss): raise RuntimeError("Loss diverged.") -- cgit v1.2.3 From fccba4729db341a299db3343e3264fecd9459a07 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Sat, 22 Oct 2022 12:02:41 +0000 Subject: add an option to avoid dying relu --- modules/hypernetworks/hypernetwork.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b7a04038..3132a56c 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -32,7 +32,6 @@ class HypernetworkModule(torch.nn.Module): assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'" linears = [] for i in range(len(layer_structure) - 1): @@ -43,12 +42,13 @@ class HypernetworkModule(torch.nn.Module): # Add an activation func if activation_func == "linear" or activation_func is None: pass + # If ReLU, Skip adding it to the first layer to avoid dying ReLU + elif activation_func == "relu" and i < 1: + pass elif activation_func in self.activation_dict: linears.append(self.activation_dict[activation_func]()) else: - raise RuntimeError( - "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'" - ) + raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') # Add dropout if use_dropout: @@ -166,8 +166,8 @@ class Hypernetwork: for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func), - HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func), + HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), ) self.name = state_dict.get('name', self.name) -- cgit v1.2.3 From 7912acef725832debef58c4c7bf8ec22fb446c0b Mon Sep 17 00:00:00 2001 From: discus0434 Date: Sat, 22 Oct 2022 13:00:44 +0000 Subject: small fix --- modules/hypernetworks/hypernetwork.py | 12 +++++------- modules/ui.py | 1 - 2 files changed, 5 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3132a56c..7d12e0ff 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -42,22 +42,20 @@ class HypernetworkModule(torch.nn.Module): # Add an activation func if activation_func == "linear" or activation_func is None: pass - # If ReLU, Skip adding it to the first layer to avoid dying ReLU - elif activation_func == "relu" and i < 1: - pass elif activation_func in self.activation_dict: linears.append(self.activation_dict[activation_func]()) else: raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') - # Add dropout - if use_dropout: - linears.append(torch.nn.Dropout(p=0.3)) - # Add layer normalization if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) + # Add dropout + if use_dropout: + p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2 + linears.append(torch.nn.Dropout(p=p)) + self.linear = torch.nn.Sequential(*linears) if state_dict is not None: diff --git a/modules/ui.py b/modules/ui.py index cd118552..eca887ca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1244,7 +1244,6 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") - new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"]) with gr.Row(): with gr.Column(scale=3): -- cgit v1.2.3 From 6a4fa73a38935a18779ce1809892730fd1572bee Mon Sep 17 00:00:00 2001 From: discus0434 Date: Sat, 22 Oct 2022 13:44:39 +0000 Subject: small fix --- modules/hypernetworks/hypernetwork.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3372aae2..3bc71ee5 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -51,10 +51,9 @@ class HypernetworkModule(torch.nn.Module): if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) - # Add dropout - if use_dropout: - p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2 - linears.append(torch.nn.Dropout(p=p)) + # Add dropout expect last layer + if use_dropout and i < len(layer_structure) - 3: + linears.append(torch.nn.Dropout(p=0.3)) self.linear = torch.nn.Sequential(*linears) -- cgit v1.2.3 From d37cfffd537cd29309afbcb192c4f979995c6a34 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 19:18:56 +0300 Subject: added callback for creating new settings in extensions --- modules/script_callbacks.py | 11 +++++++++++ modules/shared.py | 19 +++++++++++++++++-- modules/ui.py | 6 +++++- 3 files changed, 33 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 866b7acd..1270e50f 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,6 +1,7 @@ callbacks_model_loaded = [] callbacks_ui_tabs = [] +callbacks_ui_settings = [] def clear_callbacks(): @@ -22,6 +23,11 @@ def ui_tabs_callback(): return res +def ui_settings_callback(): + for callback in callbacks_ui_settings: + callback() + + def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument""" @@ -40,3 +46,8 @@ def on_ui_tabs(callback): """ callbacks_ui_tabs.append(callback) + +def on_ui_settings(callback): + """register a function to be called before UI settingsare populated; add your settings + by using shared.opts.add_option(shared.OptionInfo(...)) """ + callbacks_ui_settings.append(callback) diff --git a/modules/shared.py b/modules/shared.py index 5d83971e..d9cb65ef 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -165,13 +165,13 @@ def realesrgan_models_names(): class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None): + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): self.default = default self.label = label self.component = component self.component_args = component_args self.onchange = onchange - self.section = None + self.section = section self.refresh = refresh @@ -327,6 +327,7 @@ options_templates.update(options_section(('images-history', "Images Browser"), { })) + class Options: data = None data_labels = options_templates @@ -389,6 +390,20 @@ class Options: d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} return json.dumps(d) + def add_option(self, key, info): + self.data_labels[key] = info + + def reorder(self): + """reorder settings so that all items related to section always go together""" + + section_ids = {} + settings_items = self.data_labels.items() + for k, item in settings_items: + if item.section not in section_ids: + section_ids[item.section] = len(section_ids) + + self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])} + opts = Options() if os.path.exists(config_filename): diff --git a/modules/ui.py b/modules/ui.py index d8d52db1..2849b111 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1461,6 +1461,9 @@ def create_ui(wrap_gradio_gpu_call): components = [] component_dict = {} + script_callbacks.ui_settings_callback() + opts.reorder() + def open_folder(f): if not os.path.exists(f): print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') @@ -1564,7 +1567,8 @@ Requested path was: {f} previous_section = item.section - gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='

{}

'.format(item.section[1])) + elem_id, text = item.section + gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

{}

'.format(text)) if k in quicksettings_names: quicksettings_list.append((i, k, item)) -- cgit v1.2.3 From dbc8ab65f6d496459a76547776b656c96ad1350d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 19:19:17 +0300 Subject: typo --- modules/script_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 1270e50f..5bcccd67 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -48,6 +48,6 @@ def on_ui_tabs(callback): def on_ui_settings(callback): - """register a function to be called before UI settingsare populated; add your settings + """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ callbacks_ui_settings.append(callback) -- cgit v1.2.3 From 72383abacdc6a101704a6f73758ce4d0bb68c9d1 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Sat, 22 Oct 2022 16:50:07 +0200 Subject: Deepdanbooru linux fix --- modules/deepbooru.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 8914662d..3c34ab7c 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -50,7 +50,8 @@ def create_deepbooru_process(threshold, deepbooru_opts): the tags. """ from modules import shared # prevents circular reference - shared.deepbooru_process_manager = multiprocessing.Manager() + context = multiprocessing.get_context("spawn") + shared.deepbooru_process_manager = context.Manager() shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() shared.deepbooru_process_return["value"] = -1 -- cgit v1.2.3 From e38625011cd4955da4bc67fe95d1d0f4c0c53899 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Sat, 22 Oct 2022 16:56:52 +0200 Subject: fix part2 --- modules/deepbooru.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 3c34ab7c..8bbc90a4 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -55,7 +55,7 @@ def create_deepbooru_process(threshold, deepbooru_opts): shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts)) + shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts)) shared.deepbooru_process.start() -- cgit v1.2.3 From 324c7c732dd9afc3d4c397c354797ae5d655b514 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 20:09:37 +0300 Subject: record First pass size as 0x0 for #3328 --- modules/processing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 372489f7..27c669b0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -524,6 +524,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 + self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" + if self.firstphase_width == 0 or self.firstphase_height == 0: desired_pixel_count = 512 * 512 actual_pixel_count = self.width * self.height @@ -545,7 +547,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): firstphase_width_truncated = self.firstphase_height * self.width / self.height firstphase_height_truncated = self.firstphase_height - self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f -- cgit v1.2.3 From 0df94d3fcf9d1fc47c4d39039352a3d5b3380c1f Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Sat, 22 Oct 2022 12:59:21 -0400 Subject: fix aesthetic gradients doing nothing after loading a different model --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index f9b3063d..49dc3238 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -236,12 +236,11 @@ def load_model(checkpoint_info=None): sd_model.to(shared.device) sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) sd_model.eval() shared.sd_model = sd_model - script_callbacks.model_loaded_callback(sd_model) - print(f"Model loaded.") return sd_model @@ -268,6 +267,7 @@ def reload_model_weights(sd_model, info=None): load_model_weights(sd_model, checkpoint_info) sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) -- cgit v1.2.3 From 321bacc6a9eaf4a25f31279f288fa752be507a20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 20:15:12 +0300 Subject: call model_loaded_callback after setting shared.sd_model in case scripts refer to it using that --- modules/sd_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 49dc3238..e697bb72 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -236,11 +236,12 @@ def load_model(checkpoint_info=None): sd_model.to(shared.device) sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) sd_model.eval() shared.sd_model = sd_model + script_callbacks.model_loaded_callback(sd_model) + print(f"Model loaded.") return sd_model -- cgit v1.2.3 From 24694e5983d0944b901892cb101878e6dec89a20 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 01:57:58 +0900 Subject: Update hypernetwork.py --- modules/hypernetworks/hypernetwork.py | 55 ++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3bc71ee5..81132be4 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum +from statistics import stdev, mean class HypernetworkModule(torch.nn.Module): multiplier = 1.0 @@ -268,6 +269,32 @@ def stack_conds(conds): return torch.stack(conds) +def log_statistics(loss_info:dict, key, value): + if key not in loss_info: + loss_info[key] = [value] + else: + loss_info[key].append(value) + if len(loss_info) > 1024: + loss_info.pop(0) + + +def statistics(data): + total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})" + recent_data = data[-32:] + recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})" + return total_information, recent_information + + +def report_statistics(loss_info:dict): + keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x])) + for key in keys: + info, recent = statistics(loss_info[key]) + print("Loss statistics for file " + key) + print(info) + print(recent) + + + def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images @@ -310,7 +337,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log for weight in weights: weight.requires_grad = True - losses = torch.zeros((32,)) + size = len(ds.indexes) + loss_dict = {} + losses = torch.zeros((size,)) + previous_mean_loss = 0 + print("Mean loss of {} elements".format(size)) last_saved_file = "" last_saved_image = "" @@ -329,7 +360,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: hypernetwork.step = i + ititial_step - + if loss_dict and i % size == 0: + previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict) + scheduler.apply(optimizer, hypernetwork.step) if scheduler.finished: break @@ -346,7 +379,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log del c losses[hypernetwork.step % losses.shape[0]] = loss.item() - + for entry in entries: + log_statistics(loss_dict, entry.filename, loss.item()) + optimizer.zero_grad() weights[0].grad = None loss.backward() @@ -359,10 +394,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log optimizer.step() - mean_loss = losses.mean() - if torch.isnan(mean_loss): + if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): raise RuntimeError("Loss diverged.") - pbar.set_description(f"loss: {mean_loss:.7f}") + pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}") if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. @@ -371,7 +405,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log hypernetwork.save(last_saved_file) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { - "loss": f"{mean_loss:.7f}", + "loss": f"{previous_mean_loss:.7f}", "learn_rate": scheduler.learn_rate }) @@ -420,14 +454,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.state.textinfo = f"""

-Loss: {mean_loss:.7f}
+Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
Last prompt: {html.escape(entries[0].cond_text)}
Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - + + report_statistics(loss_dict) checkpoint = sd_models.select_checkpoint() hypernetwork.sd_checkpoint = checkpoint.hash @@ -438,5 +473,3 @@ Last saved image: {html.escape(last_saved_image)}
hypernetwork.save(filename) return hypernetwork, filename - - -- cgit v1.2.3 From 4fdb53c1e9962507fc8336dad9a0fabfe6c418c0 Mon Sep 17 00:00:00 2001 From: Unnoen Date: Wed, 19 Oct 2022 21:38:10 +1100 Subject: Generate grid preview for progress image --- modules/sd_samplers.py | 26 +++++++++++++++++++++++++- modules/shared.py | 1 + modules/ui.py | 5 ++++- 3 files changed, 30 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index f58a29b9..74a480e5 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -7,7 +7,7 @@ import inspect import k_diffusion.sampling import ldm.models.diffusion.ddim import ldm.models.diffusion.plms -from modules import prompt_parser, devices, processing +from modules import prompt_parser, devices, processing, images from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -89,6 +89,30 @@ def sample_to_image(samples): x_sample = x_sample.astype(np.uint8) return Image.fromarray(x_sample) +def samples_to_image_grid(samples): + progress_images = [] + for i in range(len(samples)): + # Decode the samples individually to reduce VRAM usage at the cost of a bit of speed. + x_sample = processing.decode_first_stage(shared.sd_model, samples[i:i+1])[0] + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + progress_images.append(Image.fromarray(x_sample)) + + return images.image_grid(progress_images) + +def samples_to_image_grid_combined(samples): + progress_images = [] + # Decode all samples at once to increase speed at the cost of VRAM usage. + x_samples = processing.decode_first_stage(shared.sd_model, samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + progress_images.append(Image.fromarray(x_sample)) + + return images.image_grid(progress_images) def store_latent(decoded): state.current_latent = decoded diff --git a/modules/shared.py b/modules/shared.py index d9cb65ef..95d6e225 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -294,6 +294,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), + "progress_decode_combined": OptionInfo(False, "Decode all progress images at once. (Slighty speeds up progress generation but consumes significantly more VRAM with large batches.)"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), diff --git a/modules/ui.py b/modules/ui.py index 56c233ab..de0abc7e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -318,7 +318,10 @@ def check_progress_call(id_part): if shared.parallel_processing_allowed: if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None: - shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent) + if opts.progress_decode_combined: + shared.state.current_image = modules.sd_samplers.samples_to_image_grid_combined(shared.state.current_latent) + else: + shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent) shared.state.current_image_sampling_step = shared.state.sampling_step image = shared.state.current_image -- cgit v1.2.3 From d213d6ca6f90094cb45c11e2f3cb37d25a8d1f94 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 20:48:13 +0300 Subject: removed the option to use 2x more memory when generating previews added an option to always only show one image in previews removed duplicate code --- modules/sd_samplers.py | 35 ++++++++++------------------------- modules/shared.py | 2 +- modules/ui.py | 6 +++--- 3 files changed, 14 insertions(+), 29 deletions(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 74a480e5..0b408a70 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -71,6 +71,7 @@ sampler_extra_params = { 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], } + def setup_img2img_steps(p, steps=None): if opts.img2img_fix_steps or steps is not None: steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 @@ -82,37 +83,21 @@ def setup_img2img_steps(p, steps=None): return steps, t_enc -def sample_to_image(samples): - x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0] +def single_sample_to_image(sample): + x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) return Image.fromarray(x_sample) + +def sample_to_image(samples): + return single_sample_to_image(samples[0]) + + def samples_to_image_grid(samples): - progress_images = [] - for i in range(len(samples)): - # Decode the samples individually to reduce VRAM usage at the cost of a bit of speed. - x_sample = processing.decode_first_stage(shared.sd_model, samples[i:i+1])[0] - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - progress_images.append(Image.fromarray(x_sample)) - - return images.image_grid(progress_images) - -def samples_to_image_grid_combined(samples): - progress_images = [] - # Decode all samples at once to increase speed at the cost of VRAM usage. - x_samples = processing.decode_first_stage(shared.sd_model, samples) - x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - - for x_sample in x_samples: - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - progress_images.append(Image.fromarray(x_sample)) - - return images.image_grid(progress_images) + return images.image_grid([single_sample_to_image(sample) for sample in samples]) + def store_latent(decoded): state.current_latent = decoded diff --git a/modules/shared.py b/modules/shared.py index 95d6e225..25bfc895 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -294,7 +294,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), - "progress_decode_combined": OptionInfo(False, "Decode all progress images at once. (Slighty speeds up progress generation but consumes significantly more VRAM with large batches.)"), + "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), diff --git a/modules/ui.py b/modules/ui.py index de0abc7e..ffa14cac 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -318,10 +318,10 @@ def check_progress_call(id_part): if shared.parallel_processing_allowed: if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None: - if opts.progress_decode_combined: - shared.state.current_image = modules.sd_samplers.samples_to_image_grid_combined(shared.state.current_latent) - else: + if opts.show_progress_grid: shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent) + else: + shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent) shared.state.current_image_sampling_step = shared.state.sampling_step image = shared.state.current_image -- cgit v1.2.3 From be748e8b086bd9834d08bdd9160649a5e7700af7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 22:05:22 +0300 Subject: add --freeze-settings commandline argument to disable changing settings --- modules/shared.py | 1 + modules/ui.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 25bfc895..b55371d3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -64,6 +64,7 @@ parser.add_argument("--port", type=int, help="launch gradio with given server po parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json')) parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) +parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) diff --git a/modules/ui.py b/modules/ui.py index ffa14cac..2311572c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -580,6 +580,9 @@ def apply_setting(key, value): if value is None: return gr.update() + if shared.cmd_opts.freeze_settings: + return gr.update() + # dont allow model to be swapped when model hash exists in prompt if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: return gr.update() @@ -1501,6 +1504,8 @@ Requested path was: {f} def run_settings(*args): changed = 0 + assert not shared.cmd_opts.freeze_settings, "changing settings is disabled" + for key, value, comp in zip(opts.data_labels.keys(), args, components): if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default): return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson() @@ -1530,6 +1535,8 @@ Requested path was: {f} return f'{changed} settings changed.', opts.dumpjson() def run_settings_single(value, key): + assert not shared.cmd_opts.freeze_settings, "changing settings is disabled" + if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() @@ -1582,7 +1589,7 @@ Requested path was: {f} elem_id, text = item.section gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

{}

'.format(text)) - if k in quicksettings_names: + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: quicksettings_list.append((i, k, item)) components.append(dummy_component) else: @@ -1615,7 +1622,7 @@ Requested path was: {f} def reload_scripts(): modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page + reload_javascript() # need to refresh the html page reload_script_bodies.click( fn=reload_scripts, -- cgit v1.2.3 From ca5a9e79dc28eeaa3a161427a82e34703bf15765 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 22:06:54 +0300 Subject: fix for img2img color correction in a batch #3218 --- modules/processing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 27c669b0..b1877b80 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -403,8 +403,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if (len(prompts) == 0): break - #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) - #c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) @@ -716,6 +714,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0) if self.overlay_images is not None: self.overlay_images = self.overlay_images * self.batch_size + + if self.color_corrections is not None and len(self.color_corrections) == 1: + self.color_corrections = self.color_corrections * self.batch_size + elif len(imgs) <= self.batch_size: self.batch_size = len(imgs) batch_images = np.array(imgs) -- cgit v1.2.3 From 48dbf99e84045ee7af55bc5b1b86492a240e631e Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 04:17:16 +0900 Subject: Allow tracking real-time loss Someone had 6000 images in their dataset, and it was shown as 0, which was confusing. This will allow tracking real time dataset-average loss for registered objects. --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 81132be4..99fd0f8f 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -360,7 +360,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: hypernetwork.step = i + ititial_step - if loss_dict and i % size == 0: + if len(loss_dict) > 0: previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict) scheduler.apply(optimizer, hypernetwork.step) -- cgit v1.2.3 From ce42879438bf2dbd76b5b346be656292e42ffb2b Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sat, 22 Oct 2022 14:53:37 -0500 Subject: fix js func signature and not forget to initialize confirmation var to prevent exception upon cancelling confirmation --- modules/ui.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 25aeba3b..e58f040e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,10 +429,12 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -def clear_prompt(_prompt, _prompt_neg, confirmed, _token_counter): +def clear_prompt(prompt, _prompt_neg, confirmed, _token_counter): """Given confirmation from a user on the client-side, go ahead with clearing prompt""" if confirmed: return ["", "", confirmed, update_token_counter("", 1)] + else: + return [prompt, _prompt_neg, confirmed, _token_counter] def connect_clear_prompt(button, prompt, prompt_neg, _dummy_confirmed, token_counter): -- cgit v1.2.3 From 1b4d04737ac513cbd55958bb60a4f85166f3484b Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 22 Oct 2022 20:13:16 -0300 Subject: Remove unused imports --- modules/api/api.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 5b0c934e..a5136b4b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,11 +1,9 @@ from modules.api.processing import StableDiffusionProcessingAPI from modules.processing import StableDiffusionProcessingTxt2Img, process_images from modules.sd_samplers import all_samplers -from modules.extras import run_pnginfo import modules.shared as shared import uvicorn -from fastapi import Body, APIRouter, HTTPException -from fastapi.responses import JSONResponse +from fastapi import APIRouter, HTTPException from pydantic import BaseModel, Field, Json import json import io @@ -18,7 +16,6 @@ class TextToImageResponse(BaseModel): parameters: Json info: Json - class Api: def __init__(self, app, queue_lock): self.router = APIRouter() -- cgit v1.2.3 From b02926df1393df311db734af149fb9faf4389cbe Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 22 Oct 2022 20:24:04 -0300 Subject: Moved moodels to their own file and extracted base64 conversion to its own function --- modules/api/api.py | 17 ++++++----------- modules/api/models.py | 8 ++++++++ 2 files changed, 14 insertions(+), 11 deletions(-) create mode 100644 modules/api/models.py (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index a5136b4b..c17d7580 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -4,17 +4,17 @@ from modules.sd_samplers import all_samplers import modules.shared as shared import uvicorn from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, Field, Json import json import io import base64 +from modules.api.models import * sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) -class TextToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: Json - info: Json +def img_to_base64(img): + buffer = io.BytesIO() + img.save(buffer, format="png") + return base64.b64encode(buffer.getvalue()) class Api: def __init__(self, app, queue_lock): @@ -41,15 +41,10 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = [] - for i in processed.images: - buffer = io.BytesIO() - i.save(buffer, format="png") - b64images.append(base64.b64encode(buffer.getvalue())) + b64images = list(map(img_to_base64, processed.images)) return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) - def img2imgapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py new file mode 100644 index 00000000..a7d247d8 --- /dev/null +++ b/modules/api/models.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel, Field, Json + +class TextToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: Json + info: Json + + \ No newline at end of file -- cgit v1.2.3 From 28e26c2bef217ae82eb9e980cceb3f67ef22e109 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 22 Oct 2022 23:13:32 -0300 Subject: Add "extra" single image operation - Separate extra modes into 3 endpoints so the user ddoesn't ahve to handle so many unused parameters. - Add response model for codumentation --- modules/api/api.py | 43 ++++++++++++++++++++++++++++++++++++++----- modules/api/models.py | 26 +++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index c17d7580..3b804373 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -8,20 +8,42 @@ import json import io import base64 from modules.api.models import * +from PIL import Image +from modules.extras import run_extras + +def upscaler_to_index(name: str): + try: + return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) + except: + raise HTTPException(status_code=400, detail="Upscaler not found") sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) -def img_to_base64(img): +def img_to_base64(img: str): buffer = io.BytesIO() img.save(buffer, format="png") return base64.b64encode(buffer.getvalue()) +def base64_to_bytes(base64Img: str): + if "," in base64Img: + base64Img = base64Img.split(",")[1] + return io.BytesIO(base64.b64decode(base64Img)) + +def base64_to_images(base64Imgs: list[str]): + imgs = [] + for img in base64Imgs: + img = Image.open(base64_to_bytes(img)) + imgs.append(img) + return imgs + + class Api: def __init__(self, app, queue_lock): self.router = APIRouter() self.app = app self.queue_lock = queue_lock - self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) + self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -45,12 +67,23 @@ class Api: return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) - def img2imgapi(self): raise NotImplementedError - def extrasapi(self): - raise NotImplementedError + def extras_single_image_api(self, req: ExtrasSingleImageRequest): + upscaler1Index = upscaler_to_index(req.upscaler_1) + upscaler2Index = upscaler_to_index(req.upscaler_2) + + reqDict = vars(req) + reqDict.pop('upscaler_1') + reqDict.pop('upscaler_2') + + reqDict['image'] = base64_to_images([reqDict['image']])[0] + + with self.queue_lock: + result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="") + + return ExtrasSingleImageResponse(image="data:image/png;base64,"+img_to_base64(result[0]), html_info_x=result[1], html_info=result[2]) def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index a7d247d8..dcf1ab54 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,8 +1,32 @@ from pydantic import BaseModel, Field, Json +from typing_extensions import Literal +from modules.shared import sd_upscalers class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: Json info: Json - \ No newline at end of file +class ExtrasBaseRequest(BaseModel): + resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.") + show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?") + gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") + codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") + codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") + upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.") + upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?") + upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") + upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") + extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") + +class ExtraBaseResponse(BaseModel): + html_info_x: str + html_info: str + +class ExtrasSingleImageRequest(ExtrasBaseRequest): + image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") + +class ExtrasSingleImageResponse(ExtraBaseResponse): + image: str = Field(default=None, title="Image", description="The generated image in base64 format.") \ No newline at end of file -- cgit v1.2.3 From 1fbfc052eb529d8cf8ce5baf578bcf93d0280c29 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 23 Oct 2022 05:43:34 +0100 Subject: Update hypernetwork.py --- modules/hypernetworks/hypernetwork.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 99fd0f8f..98a7b62e 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -288,10 +288,13 @@ def statistics(data): def report_statistics(loss_info:dict): keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x])) for key in keys: - info, recent = statistics(loss_info[key]) - print("Loss statistics for file " + key) - print(info) - print(recent) + try: + print("Loss statistics for file " + key) + info, recent = statistics(loss_info[key]) + print(info) + print(recent) + except Exception as e: + print(e) -- cgit v1.2.3 From a7c213d0f5ebb10722629b8490a5863f9ce6c4fa Mon Sep 17 00:00:00 2001 From: Stephen Date: Fri, 21 Oct 2022 19:27:40 -0400 Subject: [API][Feature] - Add img2img API endpoint --- modules/api/api.py | 58 +++++++++++++++++++++++++++++++++++++++++++---- modules/api/processing.py | 11 +++++++-- modules/processing.py | 2 +- 3 files changed, 63 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 5b0c934e..a04f2428 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,5 +1,5 @@ -from modules.api.processing import StableDiffusionProcessingAPI -from modules.processing import StableDiffusionProcessingTxt2Img, process_images +from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI +from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.sd_samplers import all_samplers from modules.extras import run_pnginfo import modules.shared as shared @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, Json import json import io import base64 +from PIL import Image sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) @@ -18,6 +19,11 @@ class TextToImageResponse(BaseModel): parameters: Json info: Json +class ImageToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: Json + info: Json + class Api: def __init__(self, app, queue_lock): @@ -25,8 +31,9 @@ class Api: self.app = app self.queue_lock = queue_lock self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) - def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) if sampler_index is None: @@ -54,8 +61,49 @@ class Api: - def img2imgapi(self): - raise NotImplementedError + def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): + sampler_index = sampler_to_index(img2imgreq.sampler_index) + + if sampler_index is None: + raise HTTPException(status_code=404, detail="Sampler not found") + + + init_images = img2imgreq.init_images + if init_images is None: + raise HTTPException(status_code=404, detail="Init image not found") + + + populate = img2imgreq.copy(update={ # Override __init__ params + "sd_model": shared.sd_model, + "sampler_index": sampler_index[0], + "do_not_save_samples": True, + "do_not_save_grid": True + } + ) + p = StableDiffusionProcessingImg2Img(**vars(populate)) + + imgs = [] + for img in init_images: + # if has a comma, deal with prefix + if "," in img: + img = img.split(",")[1] + # convert base64 to PIL image + img = base64.b64decode(img) + img = Image.open(io.BytesIO(img)) + imgs = [img] * p.batch_size + + p.init_images = imgs + # Override object param + with self.queue_lock: + processed = process_images(p) + + b64images = [] + for i in processed.images: + buffer = io.BytesIO() + i.save(buffer, format="png") + b64images.append(base64.b64encode(buffer.getvalue())) + + return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info)) def extrasapi(self): raise NotImplementedError diff --git a/modules/api/processing.py b/modules/api/processing.py index 4c541241..9f1d65c0 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -1,7 +1,8 @@ +from array import array from inflection import underscore from typing import Any, Dict, Optional from pydantic import BaseModel, Field, create_model -from modules.processing import StableDiffusionProcessingTxt2Img +from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img import inspect @@ -92,8 +93,14 @@ class PydanticModelGenerator: DynamicModel.__config__.allow_mutation = True return DynamicModel -StableDiffusionProcessingAPI = PydanticModelGenerator( +StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, [{"key": "sampler_index", "type": str, "default": "Euler"}] +).generate_model() + +StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingImg2Img", + StableDiffusionProcessingImg2Img, + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}] ).generate_model() \ No newline at end of file diff --git a/modules/processing.py b/modules/processing.py index b1877b80..1557ed8c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -623,7 +623,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): sampler = None - def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs): + def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: str=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs): super().__init__(**kwargs) self.init_images = init_images -- cgit v1.2.3 From 9e1a8b7734a2881451a2efbf80def011ea41ba49 Mon Sep 17 00:00:00 2001 From: Stephen Date: Sat, 22 Oct 2022 15:42:00 -0400 Subject: non-implemented mask with any type --- modules/api/api.py | 4 ++++ modules/api/processing.py | 2 +- modules/processing.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index a04f2428..3df6ff96 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -72,6 +72,10 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") + mask = img2imgreq.mask + if mask: + raise HTTPException(status_code=400, detail="Mask not supported yet") + populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, diff --git a/modules/api/processing.py b/modules/api/processing.py index 9f1d65c0..f551fa35 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -102,5 +102,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] ).generate_model() \ No newline at end of file diff --git a/modules/processing.py b/modules/processing.py index 1557ed8c..ff83023c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -623,7 +623,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): sampler = None - def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: str=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs): + def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs): super().__init__(**kwargs) self.init_images = init_images -- cgit v1.2.3 From 5dc0739ecdc1ade8fcf4eb77f2a503ef12489f32 Mon Sep 17 00:00:00 2001 From: Stephen Date: Sat, 22 Oct 2022 17:10:28 -0400 Subject: working mask --- modules/api/api.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 3df6ff96..3caa83a4 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -33,6 +33,14 @@ class Api: self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) + def __base64_to_image(self, base64_string): + # if has a comma, deal with prefix + if "," in base64_string: + base64_string = base64_string.split(",")[1] + imgdata = base64.b64decode(base64_string) + # convert base64 to PIL image + return Image.open(io.BytesIO(imgdata)) + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -74,26 +82,22 @@ class Api: mask = img2imgreq.mask if mask: - raise HTTPException(status_code=400, detail="Mask not supported yet") + mask = self.__base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, "sampler_index": sampler_index[0], "do_not_save_samples": True, - "do_not_save_grid": True + "do_not_save_grid": True, + "mask": mask } ) p = StableDiffusionProcessingImg2Img(**vars(populate)) imgs = [] for img in init_images: - # if has a comma, deal with prefix - if "," in img: - img = img.split(",")[1] - # convert base64 to PIL image - img = base64.b64decode(img) - img = Image.open(io.BytesIO(img)) + img = self.__base64_to_image(img) imgs = [img] * p.batch_size p.init_images = imgs -- cgit v1.2.3 From 1be5933ba21a3badec42b7b2753d626f849b609d Mon Sep 17 00:00:00 2001 From: captin411 Date: Sun, 23 Oct 2022 04:11:07 -0700 Subject: auto cropping now works with non square crops --- modules/textual_inversion/autocrop.py | 509 ++++++++++++++++++---------------- 1 file changed, 269 insertions(+), 240 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 5a551c25..b2f9241c 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -1,241 +1,270 @@ -import cv2 -from collections import defaultdict -from math import log, sqrt -import numpy as np -from PIL import Image, ImageDraw - -GREEN = "#0F0" -BLUE = "#00F" -RED = "#F00" - - -def crop_image(im, settings): - """ Intelligently crop an image to the subject matter """ - if im.height > im.width: - im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width)) - elif im.width > im.height: - im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height)) - else: - im = im.resize((settings.crop_width, settings.crop_height)) - - if im.height == im.width: - return im - - focus = focal_point(im, settings) - - # take the focal point and turn it into crop coordinates that try to center over the focal - # point but then get adjusted back into the frame - y_half = int(settings.crop_height / 2) - x_half = int(settings.crop_width / 2) - - x1 = focus.x - x_half - if x1 < 0: - x1 = 0 - elif x1 + settings.crop_width > im.width: - x1 = im.width - settings.crop_width - - y1 = focus.y - y_half - if y1 < 0: - y1 = 0 - elif y1 + settings.crop_height > im.height: - y1 = im.height - settings.crop_height - - x2 = x1 + settings.crop_width - y2 = y1 + settings.crop_height - - crop = [x1, y1, x2, y2] - - if settings.annotate_image: - d = ImageDraw.Draw(im) - rect = list(crop) - rect[2] -= 1 - rect[3] -= 1 - d.rectangle(rect, outline=GREEN) - if settings.destop_view_image: - im.show() - - return im.crop(tuple(crop)) - -def focal_point(im, settings): - corner_points = image_corner_points(im, settings) - entropy_points = image_entropy_points(im, settings) - face_points = image_face_points(im, settings) - - total_points = len(corner_points) + len(entropy_points) + len(face_points) - - corner_weight = settings.corner_points_weight - entropy_weight = settings.entropy_points_weight - face_weight = settings.face_points_weight - - weight_pref_total = corner_weight + entropy_weight + face_weight - - # weight things - pois = [] - if weight_pref_total == 0 or total_points == 0: - return pois - - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] - ) - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] - ) - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] - ) - - average_point = poi_average(pois, settings) - - if settings.annotate_image: - d = ImageDraw.Draw(im) - for f in face_points: - d.rectangle(f.bounding(f.size), outline=RED) - for f in entropy_points: - d.rectangle(f.bounding(30), outline=BLUE) - for poi in pois: - w = max(4, 4 * 0.5 * sqrt(poi.weight)) - d.ellipse(poi.bounding(w), fill=BLUE) - d.ellipse(average_point.bounding(25), outline=GREEN) - - return average_point - - -def image_face_points(im, settings): - np_im = np.array(im) - gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - - tries = [ - [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] - ] - - for t in tries: - # print(t[0]) - classifier = cv2.CascadeClassifier(t[0]) - minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side - try: - faces = classifier.detectMultiScale(gray, scaleFactor=1.1, - minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - except: - continue - - if len(faces) > 0: - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] - return [] - - -def image_corner_points(im, settings): - grayscale = im.convert("L") - - # naive attempt at preventing focal points from collecting at watermarks near the bottom - gd = ImageDraw.Draw(grayscale) - gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") - - np_im = np.array(grayscale) - - points = cv2.goodFeaturesToTrack( - np_im, - maxCorners=100, - qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.07, - useHarrisDetector=False, - ) - - if points is None: - return [] - - focal_points = [] - for point in points: - x, y = point.ravel() - focal_points.append(PointOfInterest(x, y, size=4)) - - return focal_points - - -def image_entropy_points(im, settings): - landscape = im.height < im.width - portrait = im.height > im.width - if landscape: - move_idx = [0, 2] - move_max = im.size[0] - elif portrait: - move_idx = [1, 3] - move_max = im.size[1] - else: - return [] - - e_max = 0 - crop_current = [0, 0, settings.crop_width, settings.crop_height] - crop_best = crop_current - while crop_current[move_idx[1]] < move_max: - crop = im.crop(tuple(crop_current)) - e = image_entropy(crop) - - if (e > e_max): - e_max = e - crop_best = list(crop_current) - - crop_current[move_idx[0]] += 4 - crop_current[move_idx[1]] += 4 - - x_mid = int(crop_best[0] + settings.crop_width/2) - y_mid = int(crop_best[1] + settings.crop_height/2) - - return [PointOfInterest(x_mid, y_mid, size=25)] - - -def image_entropy(im): - # greyscale image entropy - # band = np.asarray(im.convert("L")) - band = np.asarray(im.convert("1"), dtype=np.uint8) - hist, _ = np.histogram(band, bins=range(0, 256)) - hist = hist[hist > 0] - return -np.log2(hist / hist.sum()).sum() - - -def poi_average(pois, settings): - weight = 0.0 - x = 0.0 - y = 0.0 - for poi in pois: - weight += poi.weight - x += poi.x * poi.weight - y += poi.y * poi.weight - avg_x = round(x / weight) - avg_y = round(y / weight) - - return PointOfInterest(avg_x, avg_y) - - -class PointOfInterest: - def __init__(self, x, y, weight=1.0, size=10): - self.x = x - self.y = y - self.weight = weight - self.size = size - - def bounding(self, size): - return [ - self.x - size//2, - self.y - size//2, - self.x + size//2, - self.y + size//2 - ] - - -class Settings: - def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): - self.crop_width = crop_width - self.crop_height = crop_height - self.corner_points_weight = corner_points_weight - self.entropy_points_weight = entropy_points_weight - self.face_points_weight = entropy_points_weight - self.annotate_image = annotate_image +import cv2 +from collections import defaultdict +from math import log, sqrt +import numpy as np +from PIL import Image, ImageDraw + +GREEN = "#0F0" +BLUE = "#00F" +RED = "#F00" + + +def crop_image(im, settings): + """ Intelligently crop an image to the subject matter """ + + scale_by = 1 + if is_landscape(im.width, im.height): + scale_by = settings.crop_height / im.height + elif is_portrait(im.width, im.height): + scale_by = settings.crop_width / im.width + elif is_square(im.width, im.height): + if is_square(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_landscape(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_portrait(settings.crop_width, settings.crop_height): + scale_by = settings.crop_height / im.height + + im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) + + if im.width == settings.crop_width and im.height == settings.crop_height: + if settings.annotate_image: + d = ImageDraw.Draw(im) + rect = [0, 0, im.width, im.height] + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + if settings.destop_view_image: + im.show() + return im + + focus = focal_point(im, settings) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(settings.crop_height / 2) + x_half = int(settings.crop_width / 2) + + x1 = focus.x - x_half + if x1 < 0: + x1 = 0 + elif x1 + settings.crop_width > im.width: + x1 = im.width - settings.crop_width + + y1 = focus.y - y_half + if y1 < 0: + y1 = 0 + elif y1 + settings.crop_height > im.height: + y1 = im.height - settings.crop_height + + x2 = x1 + settings.crop_width + y2 = y1 + settings.crop_height + + crop = [x1, y1, x2, y2] + + if settings.annotate_image: + d = ImageDraw.Draw(im) + rect = list(crop) + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + if settings.destop_view_image: + im.show() + + return im.crop(tuple(crop)) + +def focal_point(im, settings): + corner_points = image_corner_points(im, settings) + entropy_points = image_entropy_points(im, settings) + face_points = image_face_points(im, settings) + + total_points = len(corner_points) + len(entropy_points) + len(face_points) + + corner_weight = settings.corner_points_weight + entropy_weight = settings.entropy_points_weight + face_weight = settings.face_points_weight + + weight_pref_total = corner_weight + entropy_weight + face_weight + + # weight things + pois = [] + if weight_pref_total == 0 or total_points == 0: + return pois + + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] + ) + + average_point = poi_average(pois, settings) + + if settings.annotate_image: + d = ImageDraw.Draw(im) + for f in face_points: + d.rectangle(f.bounding(f.size), outline=RED) + for f in entropy_points: + d.rectangle(f.bounding(30), outline=BLUE) + for poi in pois: + w = max(4, 4 * 0.5 * sqrt(poi.weight)) + d.ellipse(poi.bounding(w), fill=BLUE) + d.ellipse(average_point.bounding(25), outline=GREEN) + + return average_point + + +def image_face_points(im, settings): + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + + tries = [ + [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] + ] + + for t in tries: + # print(t[0]) + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + except: + continue + + if len(faces) > 0: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] + return [] + + +def image_corner_points(im, settings): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=100, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.07, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4)) + + return focal_points + + +def image_entropy_points(im, settings): + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] + else: + return [] + + e_max = 0 + crop_current = [0, 0, settings.crop_width, settings.crop_height] + crop_best = crop_current + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e > e_max): + e_max = e + crop_best = list(crop_current) + + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + settings.crop_width/2) + y_mid = int(crop_best[1] + settings.crop_height/2) + + return [PointOfInterest(x_mid, y_mid, size=25)] + + +def image_entropy(im): + # greyscale image entropy + # band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1"), dtype=np.uint8) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + + +def poi_average(pois, settings): + weight = 0.0 + x = 0.0 + y = 0.0 + for poi in pois: + weight += poi.weight + x += poi.x * poi.weight + y += poi.y * poi.weight + avg_x = round(x / weight) + avg_y = round(y / weight) + + return PointOfInterest(avg_x, avg_y) + + +def is_landscape(w, h): + return w > h + + +def is_portrait(w, h): + return h > w + + +def is_square(w, h): + return w == h + + +class PointOfInterest: + def __init__(self, x, y, weight=1.0, size=10): + self.x = x + self.y = y + self.weight = weight + self.size = size + + def bounding(self, size): + return [ + self.x - size//2, + self.y - size//2, + self.x + size//2, + self.y + size//2 + ] + + +class Settings: + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): + self.crop_width = crop_width + self.crop_height = crop_height + self.corner_points_weight = corner_points_weight + self.entropy_points_weight = entropy_points_weight + self.face_points_weight = entropy_points_weight + self.annotate_image = annotate_image self.destop_view_image = False \ No newline at end of file -- cgit v1.2.3 From 0523704dade0508bf3ae0c8cb799b1ae332d449b Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 12:27:50 -0300 Subject: Update run_extras to use the temp filename In batch mode run_extras tries to preserve the original file name of the images. The problem is that this makes no sense since the user only gets a list of images in the UI, trying to manually save them shows that this images have random temp names. Also, trying to keep "orig_name" in the API is a hassle that adds complexity to the consuming UI since the client has to use (or emulate) an input (type=file) element in a form. Using the normal file name not only doesn't change the output and functionality in the original UI but also helps keep the API simple. --- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 22c5a1c1..29ac312e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -33,7 +33,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ for img in image_folder: image = Image.open(img) imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) + imageNameArr.append(os.path.splitext(img.name)[0]) elif extras_mode == 2: assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' -- cgit v1.2.3 From 4ff852ffb50859f2eae75375cab94dd790a46886 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 13:07:59 -0300 Subject: Add batch processing "extras" endpoint --- modules/api/api.py | 25 +++++++++++++++++++++++-- modules/api/models.py | 15 ++++++++++++++- 2 files changed, 37 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 3b804373..528134a8 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -10,6 +10,7 @@ import base64 from modules.api.models import * from PIL import Image from modules.extras import run_extras +from gradio import processing_utils def upscaler_to_index(name: str): try: @@ -44,6 +45,7 @@ class Api: self.queue_lock = queue_lock self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) + self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -78,12 +80,31 @@ class Api: reqDict.pop('upscaler_1') reqDict.pop('upscaler_2') - reqDict['image'] = base64_to_images([reqDict['image']])[0] + reqDict['image'] = processing_utils.decode_base64_to_file(reqDict['image']) with self.queue_lock: result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="") - return ExtrasSingleImageResponse(image="data:image/png;base64,"+img_to_base64(result[0]), html_info_x=result[1], html_info=result[2]) + return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0]), html_info_x=result[1], html_info=result[2]) + + def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): + upscaler1Index = upscaler_to_index(req.upscaler_1) + upscaler2Index = upscaler_to_index(req.upscaler_2) + + reqDict = vars(req) + reqDict.pop('upscaler_1') + reqDict.pop('upscaler_2') + + reqDict['image_folder'] = list(map(processing_utils.decode_base64_to_file, reqDict['imageList'])) + reqDict.pop('imageList') + + with self.queue_lock: + result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=1, image="", input_dir="", output_dir="") + + return ExtrasBatchImagesResponse(images=list(map(processing_utils.encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) + + def extras_folder_processing_api(self): + raise NotImplementedError def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index dcf1ab54..bbd0ef53 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -29,4 +29,17 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") class ExtrasSingleImageResponse(ExtraBaseResponse): - image: str = Field(default=None, title="Image", description="The generated image in base64 format.") \ No newline at end of file + image: str = Field(default=None, title="Image", description="The generated image in base64 format.") + +class SerializableImage(BaseModel): + path: str = Field(title="Path", description="The image's path ()") + +class ImageItem(BaseModel): + data: str = Field(title="image data") + name: str = Field(title="filename") + +class ExtrasBatchImagesRequest(ExtrasBaseRequest): + imageList: list[str] = Field(title="Images", description="List of images to work on. Must be Base64 strings") + +class ExtrasBatchImagesResponse(ExtraBaseResponse): + images: list[str] = Field(title="Images", description="The generated images in base64 format.") \ No newline at end of file -- cgit v1.2.3 From e0ca4dfbc10e0af8dfc4185e5e758f33fd2f0d81 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 15:13:37 -0300 Subject: Update endpoints to use gradio's own utils functions --- modules/api/api.py | 75 +++++++++++++++++++++++++-------------------------- modules/api/models.py | 4 +-- 2 files changed, 38 insertions(+), 41 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 3f490ce2..3acb1f36 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -20,27 +20,27 @@ def upscaler_to_index(name: str): sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) -def img_to_base64(img: str): - buffer = io.BytesIO() - img.save(buffer, format="png") - return base64.b64encode(buffer.getvalue()) - -def base64_to_bytes(base64Img: str): - if "," in base64Img: - base64Img = base64Img.split(",")[1] - return io.BytesIO(base64.b64decode(base64Img)) - -def base64_to_images(base64Imgs: list[str]): - imgs = [] - for img in base64Imgs: - img = Image.open(base64_to_bytes(img)) - imgs.append(img) - return imgs +# def img_to_base64(img: str): +# buffer = io.BytesIO() +# img.save(buffer, format="png") +# return base64.b64encode(buffer.getvalue()) + +# def base64_to_bytes(base64Img: str): +# if "," in base64Img: +# base64Img = base64Img.split(",")[1] +# return io.BytesIO(base64.b64decode(base64Img)) + +# def base64_to_images(base64Imgs: list[str]): +# imgs = [] +# for img in base64Imgs: +# img = Image.open(base64_to_bytes(img)) +# imgs.append(img) +# return imgs class ImageToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: Json - info: Json + parameters: dict + info: str class Api: @@ -49,17 +49,17 @@ class Api: self.app = app self.queue_lock = queue_lock self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) - self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) - def __base64_to_image(self, base64_string): - # if has a comma, deal with prefix - if "," in base64_string: - base64_string = base64_string.split(",")[1] - imgdata = base64.b64decode(base64_string) - # convert base64 to PIL image - return Image.open(io.BytesIO(imgdata)) + # def __base64_to_image(self, base64_string): + # # if has a comma, deal with prefix + # if "," in base64_string: + # base64_string = base64_string.split(",")[1] + # imgdata = base64.b64decode(base64_string) + # # convert base64 to PIL image + # return Image.open(io.BytesIO(imgdata)) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -79,11 +79,9 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = list(map(img_to_base64, processed.images)) - - return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) - + b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) + return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.info) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): sampler_index = sampler_to_index(img2imgreq.sampler_index) @@ -98,7 +96,7 @@ class Api: mask = img2imgreq.mask if mask: - mask = self.__base64_to_image(mask) + mask = processing_utils.decode_base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params @@ -113,7 +111,7 @@ class Api: imgs = [] for img in init_images: - img = self.__base64_to_image(img) + img = processing_utils.decode_base64_to_image(img) imgs = [img] * p.batch_size p.init_images = imgs @@ -121,13 +119,12 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = [] - for i in processed.images: - buffer = io.BytesIO() - i.save(buffer, format="png") - b64images.append(base64.b64encode(buffer.getvalue())) - - return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info)) + b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) + # for i in processed.images: + # buffer = io.BytesIO() + # i.save(buffer, format="png") + # b64images.append(base64.b64encode(buffer.getvalue())) + return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info) def extras_single_image_api(self, req: ExtrasSingleImageRequest): upscaler1Index = upscaler_to_index(req.upscaler_1) diff --git a/modules/api/models.py b/modules/api/models.py index bbd0ef53..209f8af5 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -4,8 +4,8 @@ from modules.shared import sd_upscalers class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: Json - info: Json + parameters: str + info: str class ExtrasBaseRequest(BaseModel): resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.") -- cgit v1.2.3 From 866b36d705a338d299aba385788729d60f7d48c8 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 15:35:49 -0300 Subject: Move processing's models into models.py It didn't make sense to have two differente files for the same and "models" is a more descriptive name. --- modules/api/api.py | 57 ++++------------------- modules/api/models.py | 112 +++++++++++++++++++++++++++++++++++++++++++++- modules/api/processing.py | 106 ------------------------------------------- 3 files changed, 119 insertions(+), 156 deletions(-) delete mode 100644 modules/api/processing.py (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 3acb1f36..20e85e82 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,16 +1,11 @@ -from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI -from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers -import modules.shared as shared import uvicorn +from gradio import processing_utils from fastapi import APIRouter, HTTPException -import json -import io -import base64 +import modules.shared as shared from modules.api.models import * -from PIL import Image +from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +from modules.sd_samplers import all_samplers from modules.extras import run_extras -from gradio import processing_utils def upscaler_to_index(name: str): try: @@ -20,29 +15,6 @@ def upscaler_to_index(name: str): sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) -# def img_to_base64(img: str): -# buffer = io.BytesIO() -# img.save(buffer, format="png") -# return base64.b64encode(buffer.getvalue()) - -# def base64_to_bytes(base64Img: str): -# if "," in base64Img: -# base64Img = base64Img.split(",")[1] -# return io.BytesIO(base64.b64decode(base64Img)) - -# def base64_to_images(base64Imgs: list[str]): -# imgs = [] -# for img in base64Imgs: -# img = Image.open(base64_to_bytes(img)) -# imgs.append(img) -# return imgs - -class ImageToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: dict - info: str - - class Api: def __init__(self, app, queue_lock): self.router = APIRouter() @@ -51,15 +23,7 @@ class Api: self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) - self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) - - # def __base64_to_image(self, base64_string): - # # if has a comma, deal with prefix - # if "," in base64_string: - # base64_string = base64_string.split(",")[1] - # imgdata = base64.b64decode(base64_string) - # # convert base64 to PIL image - # return Image.open(io.BytesIO(imgdata)) + self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -81,7 +45,7 @@ class Api: b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) - return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.info) + return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.info) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): sampler_index = sampler_to_index(img2imgreq.sampler_index) @@ -120,10 +84,7 @@ class Api: processed = process_images(p) b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) - # for i in processed.images: - # buffer = io.BytesIO() - # i.save(buffer, format="png") - # b64images.append(base64.b64encode(buffer.getvalue())) + return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info) def extras_single_image_api(self, req: ExtrasSingleImageRequest): @@ -134,12 +95,12 @@ class Api: reqDict.pop('upscaler_1') reqDict.pop('upscaler_2') - reqDict['image'] = processing_utils.decode_base64_to_file(reqDict['image']) + reqDict['image'] = processing_utils.decode_base64_to_image(reqDict['image']) with self.queue_lock: result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="") - return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0]), html_info_x=result[1], html_info=result[2]) + return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2]) def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): upscaler1Index = upscaler_to_index(req.upscaler_1) diff --git a/modules/api/models.py b/modules/api/models.py index 209f8af5..362e6277 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,10 +1,118 @@ -from pydantic import BaseModel, Field, Json +import inspect +from pydantic import BaseModel, Field, Json, create_model +from typing import Any, Optional from typing_extensions import Literal +from inflection import underscore +from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img from modules.shared import sd_upscalers +API_NOT_ALLOWED = [ + "self", + "kwargs", + "sd_model", + "outpath_samples", + "outpath_grids", + "sampler_index", + "do_not_save_samples", + "do_not_save_grid", + "extra_generation_params", + "overlay_images", + "do_not_reload_embeddings", + "seed_enable_extras", + "prompt_for_display", + "sampler_noise_scheduler_override", + "ddim_discretize" +] + +class ModelDef(BaseModel): + """Assistance Class for Pydantic Dynamic Model Generation""" + + field: str + field_alias: str + field_type: Any + field_value: Any + + +class PydanticModelGenerator: + """ + Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: + source_data is a snapshot of the default values produced by the class + params are the names of the actual keys required by __init__ + """ + + def __init__( + self, + model_name: str = None, + class_instance = None, + additional_fields = None, + ): + def field_type_generator(k, v): + # field_type = str if not overrides.get(k) else overrides[k]["type"] + # print(k, v.annotation, v.default) + field_type = v.annotation + + return Optional[field_type] + + def merge_class_params(class_): + all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) + parameters = {} + for classes in all_classes: + parameters = {**parameters, **inspect.signature(classes.__init__).parameters} + return parameters + + + self._model_name = model_name + self._class_data = merge_class_params(class_instance) + self._model_def = [ + ModelDef( + field=underscore(k), + field_alias=k, + field_type=field_type_generator(k, v), + field_value=v.default + ) + for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED + ] + + for fields in additional_fields: + self._model_def.append(ModelDef( + field=underscore(fields["key"]), + field_alias=fields["key"], + field_type=fields["type"], + field_value=fields["default"])) + + def generate_model(self): + """ + Creates a pydantic BaseModel + from the json and overrides provided at initialization + """ + fields = { + d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def + } + DynamicModel = create_model(self._model_name, **fields) + DynamicModel.__config__.allow_population_by_field_name = True + DynamicModel.__config__.allow_mutation = True + return DynamicModel + +StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingTxt2Img", + StableDiffusionProcessingTxt2Img, + [{"key": "sampler_index", "type": str, "default": "Euler"}] +).generate_model() + +StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingImg2Img", + StableDiffusionProcessingImg2Img, + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] +).generate_model() + class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: str + parameters: dict + info: str + +class ImageToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: dict info: str class ExtrasBaseRequest(BaseModel): diff --git a/modules/api/processing.py b/modules/api/processing.py deleted file mode 100644 index f551fa35..00000000 --- a/modules/api/processing.py +++ /dev/null @@ -1,106 +0,0 @@ -from array import array -from inflection import underscore -from typing import Any, Dict, Optional -from pydantic import BaseModel, Field, create_model -from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img -import inspect - - -API_NOT_ALLOWED = [ - "self", - "kwargs", - "sd_model", - "outpath_samples", - "outpath_grids", - "sampler_index", - "do_not_save_samples", - "do_not_save_grid", - "extra_generation_params", - "overlay_images", - "do_not_reload_embeddings", - "seed_enable_extras", - "prompt_for_display", - "sampler_noise_scheduler_override", - "ddim_discretize" -] - -class ModelDef(BaseModel): - """Assistance Class for Pydantic Dynamic Model Generation""" - - field: str - field_alias: str - field_type: Any - field_value: Any - - -class PydanticModelGenerator: - """ - Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: - source_data is a snapshot of the default values produced by the class - params are the names of the actual keys required by __init__ - """ - - def __init__( - self, - model_name: str = None, - class_instance = None, - additional_fields = None, - ): - def field_type_generator(k, v): - # field_type = str if not overrides.get(k) else overrides[k]["type"] - # print(k, v.annotation, v.default) - field_type = v.annotation - - return Optional[field_type] - - def merge_class_params(class_): - all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) - parameters = {} - for classes in all_classes: - parameters = {**parameters, **inspect.signature(classes.__init__).parameters} - return parameters - - - self._model_name = model_name - self._class_data = merge_class_params(class_instance) - self._model_def = [ - ModelDef( - field=underscore(k), - field_alias=k, - field_type=field_type_generator(k, v), - field_value=v.default - ) - for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED - ] - - for fields in additional_fields: - self._model_def.append(ModelDef( - field=underscore(fields["key"]), - field_alias=fields["key"], - field_type=fields["type"], - field_value=fields["default"])) - - def generate_model(self): - """ - Creates a pydantic BaseModel - from the json and overrides provided at initialization - """ - fields = { - d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def - } - DynamicModel = create_model(self._model_name, **fields) - DynamicModel.__config__.allow_population_by_field_name = True - DynamicModel.__config__.allow_mutation = True - return DynamicModel - -StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( - "StableDiffusionProcessingTxt2Img", - StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}] -).generate_model() - -StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( - "StableDiffusionProcessingImg2Img", - StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] -).generate_model() \ No newline at end of file -- cgit v1.2.3 From 1e625624ba6ab3dfc167f0a5226780bb9b50fb58 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 16:01:16 -0300 Subject: Add folder processing endpoint Also minor refactor --- modules/api/api.py | 56 +++++++++++++++++++++++++++------------------------ modules/api/models.py | 6 +++++- 2 files changed, 35 insertions(+), 27 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 20e85e82..7b4fbe29 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,5 +1,5 @@ import uvicorn -from gradio import processing_utils +from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, HTTPException import modules.shared as shared from modules.api.models import * @@ -11,10 +11,18 @@ def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) except: - raise HTTPException(status_code=400, detail="Upscaler not found") + raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) +def setUpscalers(req: dict): + reqDict = vars(req) + reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) + reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2) + reqDict.pop('upscaler_1') + reqDict.pop('upscaler_2') + return reqDict + class Api: def __init__(self, app, queue_lock): self.router = APIRouter() @@ -24,6 +32,7 @@ class Api: self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) + self.app.add_api_route("/sdapi/v1/extra-folder-images", self.extras_folder_processing_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -43,7 +52,7 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) + b64images = list(map(encode_pil_to_base64, processed.images)) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.info) @@ -60,7 +69,7 @@ class Api: mask = img2imgreq.mask if mask: - mask = processing_utils.decode_base64_to_image(mask) + mask = decode_base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params @@ -75,7 +84,7 @@ class Api: imgs = [] for img in init_images: - img = processing_utils.decode_base64_to_image(img) + img = decode_base64_to_image(img) imgs = [img] * p.batch_size p.init_images = imgs @@ -83,43 +92,38 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) + b64images = list(map(encode_pil_to_base64, processed.images)) return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info) def extras_single_image_api(self, req: ExtrasSingleImageRequest): - upscaler1Index = upscaler_to_index(req.upscaler_1) - upscaler2Index = upscaler_to_index(req.upscaler_2) - - reqDict = vars(req) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') + reqDict = setUpscalers(req) - reqDict['image'] = processing_utils.decode_base64_to_image(reqDict['image']) + reqDict['image'] = decode_base64_to_image(reqDict['image']) with self.queue_lock: - result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="") + result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict) - return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2]) + return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2]) def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): - upscaler1Index = upscaler_to_index(req.upscaler_1) - upscaler2Index = upscaler_to_index(req.upscaler_2) + reqDict = setUpscalers(req) - reqDict = vars(req) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') - - reqDict['image_folder'] = list(map(processing_utils.decode_base64_to_file, reqDict['imageList'])) + reqDict['image_folder'] = list(map(decode_base64_to_file, reqDict['imageList'])) reqDict.pop('imageList') with self.queue_lock: - result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=1, image="", input_dir="", output_dir="") + result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict) - return ExtrasBatchImagesResponse(images=list(map(processing_utils.encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) + return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) - def extras_folder_processing_api(self): - raise NotImplementedError + def extras_folder_processing_api(self, req:ExtrasFoldersRequest): + reqDict = setUpscalers(req) + + with self.queue_lock: + result = run_extras(extras_mode=2, image=None, image_folder=None, **reqDict) + + return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index 362e6277..6f096807 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -150,4 +150,8 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest): imageList: list[str] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ExtrasBatchImagesResponse(ExtraBaseResponse): - images: list[str] = Field(title="Images", description="The generated images in base64 format.") \ No newline at end of file + images: list[str] = Field(title="Images", description="The generated images in base64 format.") + +class ExtrasFoldersRequest(ExtrasBaseRequest): + input_dir: str = Field(title="Input directory", description="Directory path from where to take the images") + output_dir: str = Field(title="Output directory", description="Directory path to put the processsed images into") -- cgit v1.2.3 From 90f02c75220d187e075203a4e3b450bfba392c4d Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 16:03:30 -0300 Subject: Remove unused field and class --- modules/api/api.py | 6 +++--- modules/api/models.py | 6 +----- 2 files changed, 4 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 7b4fbe29..799e3701 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -104,7 +104,7 @@ class Api: with self.queue_lock: result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict) - return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2]) + return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): reqDict = setUpscalers(req) @@ -115,7 +115,7 @@ class Api: with self.queue_lock: result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict) - return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) + return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) def extras_folder_processing_api(self, req:ExtrasFoldersRequest): reqDict = setUpscalers(req) @@ -123,7 +123,7 @@ class Api: with self.queue_lock: result = run_extras(extras_mode=2, image=None, image_folder=None, **reqDict) - return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) + return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index 6f096807..e461d397 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -130,8 +130,7 @@ class ExtrasBaseRequest(BaseModel): extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") class ExtraBaseResponse(BaseModel): - html_info_x: str - html_info: str + html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.") class ExtrasSingleImageRequest(ExtrasBaseRequest): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") @@ -139,9 +138,6 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest): class ExtrasSingleImageResponse(ExtraBaseResponse): image: str = Field(default=None, title="Image", description="The generated image in base64 format.") -class SerializableImage(BaseModel): - path: str = Field(title="Path", description="The image's path ()") - class ImageItem(BaseModel): data: str = Field(title="image data") name: str = Field(title="filename") -- cgit v1.2.3 From 124e44cf1eed1edc68954f63a2a9bc428aabbcec Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 24 Oct 2022 09:51:56 +0800 Subject: remove browser to extension --- modules/images_history.py | 424 -------------------------------------------- modules/inspiration.py | 193 -------------------- modules/script_callbacks.py | 2 - modules/shared.py | 15 -- modules/ui.py | 20 +-- 5 files changed, 4 insertions(+), 650 deletions(-) delete mode 100644 modules/images_history.py delete mode 100644 modules/inspiration.py (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py deleted file mode 100644 index bc5cf11f..00000000 --- a/modules/images_history.py +++ /dev/null @@ -1,424 +0,0 @@ -import os -import shutil -import time -import hashlib -import gradio -system_bak_path = "webui_log_and_bak" -custom_tab_name = "custom fold" -faverate_tab_name = "favorites" -tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name] -def is_valid_date(date): - try: - time.strptime(date, "%Y%m%d") - return True - except: - return False - -def reduplicative_file_move(src, dst): - def same_name_file(basename, path): - name, ext = os.path.splitext(basename) - f_list = os.listdir(path) - max_num = 0 - for f in f_list: - if len(f) <= len(basename): - continue - f_ext = f[-len(ext):] if len(ext) > 0 else "" - if f[:len(name)] == name and f_ext == ext: - if f[len(name)] == "(" and f[-len(ext)-1] == ")": - number = f[len(name)+1:-len(ext)-1] - if number.isdigit(): - if int(number) > max_num: - max_num = int(number) - return f"{name}({max_num + 1}){ext}" - name = os.path.basename(src) - save_name = os.path.join(dst, name) - if not os.path.exists(save_name): - shutil.move(src, dst) - else: - name = same_name_file(name, dst) - shutil.move(src, os.path.join(dst, name)) - -def traverse_all_files(curr_path, image_list, all_type=False): - try: - f_list = os.listdir(curr_path) - except: - if all_type or (curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt" and curr_path[-4:] != ".csv"): - image_list.append(curr_path) - return image_list - for file in f_list: - file = os.path.join(curr_path, file) - if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"): - pass - elif os.path.isfile(file) and file[-10:].rfind(".") > 0: - image_list.append(file) - else: - image_list = traverse_all_files(file, image_list) - return image_list - -def auto_sorting(dir_name): - bak_path = os.path.join(dir_name, system_bak_path) - if not os.path.exists(bak_path): - os.mkdir(bak_path) - log_file = None - files_list = [] - f_list = os.listdir(dir_name) - for file in f_list: - if file == system_bak_path: - continue - file_path = os.path.join(dir_name, file) - if not is_valid_date(file): - if file[-10:].rfind(".") > 0: - files_list.append(file_path) - else: - files_list = traverse_all_files(file_path, files_list, all_type=True) - - for file in files_list: - date_str = time.strftime("%Y%m%d",time.localtime(os.path.getmtime(file))) - file_path = os.path.dirname(file) - hash_path = hashlib.md5(file_path.encode()).hexdigest() - path = os.path.join(dir_name, date_str, hash_path) - if not os.path.exists(path): - os.makedirs(path) - if log_file is None: - log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a") - log_file.write(f"{hash_path},{file_path}\n") - reduplicative_file_move(file, path) - - date_list = [] - f_list = os.listdir(dir_name) - for f in f_list: - if is_valid_date(f): - date_list.append(f) - elif f == system_bak_path: - continue - else: - try: - reduplicative_file_move(os.path.join(dir_name, f), bak_path) - except: - pass - - today = time.strftime("%Y%m%d",time.localtime(time.time())) - if today not in date_list: - date_list.append(today) - return sorted(date_list, reverse=True) - -def archive_images(dir_name, date_to): - filenames = [] - batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num) - if batch_size <= 0: - batch_size = opts.images_history_num_per_page * 6 - today = time.strftime("%Y%m%d",time.localtime(time.time())) - date_to = today if date_to is None or date_to == "" else date_to - date_to_bak = date_to - if False: #opts.images_history_reconstruct_directory: - date_list = auto_sorting(dir_name) - for date in date_list: - if date <= date_to: - path = os.path.join(dir_name, date) - if date == today and not os.path.exists(path): - continue - filenames = traverse_all_files(path, filenames) - if len(filenames) > batch_size: - break - filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file)) - else: - filenames = traverse_all_files(dir_name, filenames) - total_num = len(filenames) - tmparray = [(os.path.getmtime(file), file) for file in filenames ] - date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400 - filenames = [] - date_list = {date_to:None} - date = time.strftime("%Y%m%d",time.localtime(time.time())) - for t, f in tmparray: - date = time.strftime("%Y%m%d",time.localtime(t)) - date_list[date] = None - if t <= date_stamp: - filenames.append((t, f ,date)) - date_list = sorted(list(date_list.keys()), reverse=True) - sort_array = sorted(filenames, key=lambda x:-x[0]) - if len(sort_array) > batch_size: - date = sort_array[batch_size][2] - filenames = [x[1] for x in sort_array] - else: - date = date_to if len(sort_array) == 0 else sort_array[-1][2] - filenames = [x[1] for x in sort_array] - filenames = [x[1] for x in sort_array if x[2]>= date] - num = len(filenames) - last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000)) - date = date[:4] + "/" + date[4:6] + "/" + date[6:8] - date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8] - load_info = "
" - load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages" - load_info += "
" - _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames) - return ( - date_to, - load_info, - filenames, - 1, - image_list, - "", - "", - visible_num, - last_date_from, - gradio.update(visible=total_num > num) - ) - -def delete_image(delete_num, name, filenames, image_index, visible_num): - if name == "": - return filenames, delete_num - else: - delete_num = int(delete_num) - visible_num = int(visible_num) - image_index = int(image_index) - index = list(filenames).index(name) - i = 0 - new_file_list = [] - for name in filenames: - if i >= index and i < index + delete_num: - if os.path.exists(name): - if visible_num == image_index: - new_file_list.append(name) - i += 1 - continue - print(f"Delete file {name}") - os.remove(name) - visible_num -= 1 - txt_file = os.path.splitext(name)[0] + ".txt" - if os.path.exists(txt_file): - os.remove(txt_file) - else: - print(f"Not exists file {name}") - else: - new_file_list.append(name) - i += 1 - return new_file_list, 1, visible_num - -def save_image(file_name): - if file_name is not None and os.path.exists(file_name): - shutil.copy(file_name, opts.outdir_save) - -def get_recent_images(page_index, step, filenames): - page_index = int(page_index) - num_of_imgs_per_page = int(opts.images_history_num_per_page) - max_page_index = len(filenames) // num_of_imgs_per_page + 1 - page_index = max_page_index if page_index == -1 else page_index + step - page_index = 1 if page_index < 1 else page_index - page_index = max_page_index if page_index > max_page_index else page_index - idx_frm = (page_index - 1) * num_of_imgs_per_page - image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page] - length = len(filenames) - visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page - visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num - return page_index, image_list, "", "", visible_num - -def loac_batch_click(date_to): - if date_to is None: - return time.strftime("%Y%m%d",time.localtime(time.time())), [] - else: - return None, [] -def forward_click(last_date_from, date_to_recorder): - if len(date_to_recorder) == 0: - return None, [] - if last_date_from == date_to_recorder[-1]: - date_to_recorder = date_to_recorder[:-1] - if len(date_to_recorder) == 0: - return None, [] - return date_to_recorder[-1], date_to_recorder[:-1] - -def backward_click(last_date_from, date_to_recorder): - if last_date_from is None or last_date_from == "": - return time.strftime("%Y%m%d",time.localtime(time.time())), [] - if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]: - date_to_recorder.append(last_date_from) - return last_date_from, date_to_recorder - - -def first_page_click(page_index, filenames): - return get_recent_images(1, 0, filenames) - -def end_page_click(page_index, filenames): - return get_recent_images(-1, 0, filenames) - -def prev_page_click(page_index, filenames): - return get_recent_images(page_index, -1, filenames) - -def next_page_click(page_index, filenames): - return get_recent_images(page_index, 1, filenames) - -def page_index_change(page_index, filenames): - return get_recent_images(page_index, 0, filenames) - -def show_image_info(tabname_box, num, page_index, filenames): - file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] - tm = "
" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "
" - return file, tm, num, file - -def enable_page_buttons(): - return gradio.update(visible=True) - -def change_dir(img_dir, date_to): - warning = None - try: - if os.path.exists(img_dir): - try: - f = os.listdir(img_dir) - except: - warning = f"'{img_dir} is not a directory" - else: - warning = "The directory is not exist" - except: - warning = "The format of the directory is incorrect" - if warning is None: - today = time.strftime("%Y%m%d",time.localtime(time.time())) - return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True) - else: - return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False) - -def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): - custom_dir = False - if tabname == "txt2img": - dir_name = opts.outdir_txt2img_samples - elif tabname == "img2img": - dir_name = opts.outdir_img2img_samples - elif tabname == "extras": - dir_name = opts.outdir_extras_samples - elif tabname == faverate_tab_name: - dir_name = opts.outdir_save - else: - custom_dir = True - dir_name = None - - if not custom_dir: - d = dir_name.split("/") - dir_name = d[0] - for p in d[1:]: - dir_name = os.path.join(dir_name, p) - if not os.path.exists(dir_name): - os.makedirs(dir_name) - - with gr.Column() as page_panel: - with gr.Row(): - with gr.Column(scale=1, visible=not custom_dir) as load_batch_box: - load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True) - with gr.Column(scale=4): - with gr.Row(): - img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir) - with gr.Row(): - with gr.Column(visible=False, scale=1) as batch_panel: - with gr.Row(): - forward = gr.Button('Prev batch') - backward = gr.Button('Next batch') - with gr.Column(scale=3): - load_info = gr.HTML(visible=not custom_dir) - with gr.Row(visible=False) as warning: - warning_box = gr.Textbox("Message", interactive=False) - - with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel: - with gr.Column(scale=2): - with gr.Row(visible=True) as turn_page_buttons: - #date_to = gr.Dropdown(label="Date to") - first_page = gr.Button('First Page') - prev_page = gr.Button('Prev Page') - page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next Page') - end_page = gr.Button('End Page') - - history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num) - with gr.Row(): - delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") - delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - - with gr.Column(): - with gr.Row(): - with gr.Column(): - img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6) - gr.HTML("
") - img_file_name = gr.Textbox(value="", label="File Name", interactive=False) - img_file_time= gr.HTML() - with gr.Row(): - if tabname != faverate_tab_name: - save_btn = gr.Button('Collect') - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') - - - # hiden items - with gr.Row(visible=False): - renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") - batch_date_to = gr.Textbox(label="Date to") - visible_img_num = gr.Number() - date_to_recorder = gr.State([]) - last_date_from = gr.Textbox() - tabname_box = gr.Textbox(tabname) - image_index = gr.Textbox(value=-1) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") - filenames = gr.State() - all_images_list = gr.State() - hidden = gr.Image(type="pil") - info1 = gr.Textbox() - info2 = gr.Textbox() - - img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info]) - - #change batch - change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel] - - batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output) - batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) - batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - - load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder]) - forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) - backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) - - - #delete - delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num]) - delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None) - if tabname != faverate_tab_name: - save_btn.click(save_image, inputs=[img_file_name], outputs=None) - - #turn page - gallery_inputs = [page_index, filenames] - gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num] - first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) - renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) - - first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - - # other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden]) - img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) - hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) - switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') - switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') - - - -def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict): - global opts; - opts = sys_opts - loads_files_num = int(opts.images_history_num_per_page) - num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num) - if cmp_ops.browse_all_images: - tabs_list.append(custom_tab_name) - with gr.Blocks(analytics_enabled=False) as images_history: - with gr.Tabs() as tabs: - for tab in tabs_list: - with gr.Tab(tab): - with gr.Blocks(analytics_enabled=False) : - show_images_history(gr, opts, tab, run_pnginfo, switch_dict) - gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False) - gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False) - - return images_history diff --git a/modules/inspiration.py b/modules/inspiration.py deleted file mode 100644 index 29cf8297..00000000 --- a/modules/inspiration.py +++ /dev/null @@ -1,193 +0,0 @@ -import os -import random -import gradio -from modules.shared import opts -inspiration_system_path = os.path.join(opts.inspiration_dir, "system") -def read_name_list(file, types=None, keyword=None): - if not os.path.exists(file): - return [] - ret = [] - f = open(file, "r") - line = f.readline() - while len(line) > 0: - line = line.rstrip("\n") - if types is not None: - dirname = os.path.split(line) - if dirname[0] in types and keyword in dirname[1].lower(): - ret.append(line) - else: - ret.append(line) - line = f.readline() - return ret - -def save_name_list(file, name): - name_list = read_name_list(file) - if name not in name_list: - with open(file, "a") as f: - f.write(name + "\n") - -def get_types_list(): - files = os.listdir(opts.inspiration_dir) - types = [] - for x in files: - path = os.path.join(opts.inspiration_dir, x) - if x[0] == ".": - continue - if not os.path.isdir(path): - continue - if path == inspiration_system_path: - continue - types.append(x) - return types - -def get_inspiration_images(source, types, keyword): - keyword = keyword.strip(" ").lower() - get_num = int(opts.inspiration_rows_num * opts.inspiration_cols_num) - if source == "Favorites": - names = read_name_list(os.path.join(inspiration_system_path, "faverites.txt"), types, keyword) - names = random.sample(names, get_num) if len(names) > get_num else names - elif source == "Abandoned": - names = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) - names = random.sample(names, get_num) if len(names) > get_num else names - elif source == "Exclude abandoned": - abandoned = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) - all_names = [] - for tp in types: - name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) - all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()] - - if len(all_names) > get_num: - names = [] - while len(names) < get_num: - name = random.choice(all_names) - if name not in abandoned: - names.append(name) - else: - names = all_names - else: - all_names = [] - for tp in types: - name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) - all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()] - names = random.sample(all_names, get_num) if len(all_names) > get_num else all_names - image_list = [] - for a in names: - image_path = os.path.join(opts.inspiration_dir, a) - images = os.listdir(image_path) - if len(images) > 0: - image_list.append((os.path.join(image_path, random.choice(images)), a)) - else: - print(image_path) - return image_list, names - -def select_click(index, name_list): - name = name_list[int(index)] - path = os.path.join(opts.inspiration_dir, name) - images = os.listdir(path) - return name, [os.path.join(path, x) for x in images], "" - -def give_up_click(name): - file = os.path.join(inspiration_system_path, "abandoned.txt") - save_name_list(file, name) - return "Added to abandoned list" - -def collect_click(name): - file = os.path.join(inspiration_system_path, "faverites.txt") - save_name_list(file, name) - return "Added to faverite list" - -def moveout_click(name, source): - if source == "Abandoned": - file = os.path.join(inspiration_system_path, "abandoned.txt") - elif source == "Favorites": - file = os.path.join(inspiration_system_path, "faverites.txt") - else: - return None - name_list = read_name_list(file) - os.remove(file) - with open(file, "a") as f: - for a in name_list: - if a != name: - f.write(a + "\n") - return f"Moved out {name} from {source} list" - -def source_change(source): - if source in ["Abandoned", "Favorites"]: - return gradio.update(visible=True), [] - else: - return gradio.update(visible=False), [] -def add_to_prompt(name, prompt): - name = os.path.basename(name) - return prompt + "," + name - -def clear_keyword(): - return "" - -def ui(gr, opts, txt2img_prompt, img2img_prompt): - with gr.Blocks(analytics_enabled=False) as inspiration: - flag = os.path.exists(opts.inspiration_dir) - if flag: - types = get_types_list() - flag = len(types) > 0 - else: - os.makedirs(opts.inspiration_dir) - if not flag: - gr.HTML(""" -

To activate inspiration function, you need get "inspiration" images first.


- You can create these images by run "Create inspiration images" script in txt2img page,
you can get the artists or art styles list from here
- https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
- download these files, and select these files in the "Create inspiration images" script UI
- There about 6000 artists and art styles in these files.
This takes server hours depending on your GPU type and how many pictures you generate for each artist/style -
I suggest at least four images for each


-

You can also download generated pictures from here:


- https://huggingface.co/datasets/yfszzx/inspiration
- unzip the file to the project directory of webui
- and restart webui, and enjoy the joy of creation!
- """) - return inspiration - if not os.path.exists(inspiration_system_path): - os.mkdir(inspiration_system_path) - with gr.Row(): - with gr.Column(scale=2): - inspiration_gallery = gr.Gallery(show_label=False, elem_id="inspiration_gallery").style(grid=opts.inspiration_cols_num, height='auto') - with gr.Column(scale=1): - types = gr.CheckboxGroup(choices=types, value=types) - with gr.Row(): - source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") - keyword = gr.Textbox("", label="Key word") - get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button") - name = gr.Textbox(show_label=False, interactive=False) - with gr.Row(): - send_to_txt2img = gr.Button('to txt2img') - send_to_img2img = gr.Button('to img2img') - collect = gr.Button('Collect') - give_up = gr.Button("Don't show again") - moveout = gr.Button("Move out", visible=False) - warning = gr.HTML() - style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') - - - - with gr.Row(visible=False): - select_button = gr.Button('set button', elem_id="inspiration_select_button") - name_list = gr.State() - - get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list]) - keyword.submit(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) - source.change(source_change, inputs=[source], outputs=[moveout, style_gallery]) - source.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword]) - types.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword]) - - select_button.click(select_click, _js="inspiration_selected", inputs=[name, name_list], outputs=[name, style_gallery, warning]) - give_up.click(give_up_click, inputs=[name], outputs=[warning]) - collect.click(collect_click, inputs=[name], outputs=[warning]) - moveout.click(moveout_click, inputs=[name, source], outputs=[warning]) - moveout.click(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) - - send_to_txt2img.click(add_to_prompt, inputs=[name, txt2img_prompt], outputs=[txt2img_prompt]) - send_to_img2img.click(add_to_prompt, inputs=[name, img2img_prompt], outputs=[img2img_prompt]) - send_to_txt2img.click(collect_click, inputs=[name], outputs=[warning]) - send_to_img2img.click(collect_click, inputs=[name], outputs=[warning]) - send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None) - send_to_img2img.click(None, _js="switch_to_img2img_img2img", inputs=None, outputs=None) - return inspiration diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 5bcccd67..66666a56 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,4 +1,3 @@ - callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] @@ -16,7 +15,6 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] - for callback in callbacks_ui_tabs: res += callback() or [] diff --git a/modules/shared.py b/modules/shared.py index 0aaaadac..5dfd7927 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -321,21 +321,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) -options_templates.update(options_section(('inspiration', "Inspiration"), { - "inspiration_dir": OptionInfo("inspiration", "Directory of inspiration", component_args=hide_dirs), - "inspiration_max_samples": OptionInfo(4, "Maximum number of samples, used to determine which folders to skip when continue running the create script", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), - "inspiration_rows_num": OptionInfo(4, "Rows of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), - "inspiration_cols_num": OptionInfo(8, "Columns of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), -})) - -options_templates.update(options_section(('images-history', "Images Browser"), { - #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), - "images_history_preload": OptionInfo(False, "Preload images at startup"), - "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"), - "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "), - "images_history_grid_num": OptionInfo(6, "Number of grids in each row"), - -})) class Options: data = None diff --git a/modules/ui.py b/modules/ui.py index a73175f5..fa42712e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -49,14 +49,12 @@ from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.images_history as images_history -import modules.inspiration as inspiration - - # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') +txt2img_paste_fields = [] +img2img_paste_fields = [] if not cmd_opts.share and not cmd_opts.listen: @@ -1193,16 +1191,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[image], outputs=[html, generation_info, html2], ) - #images history - images_history_switch_dict = { - "fn": modules.generation_parameters_copypaste.connect_paste, - "t2i": txt2img_paste_fields, - "i2i": img2img_paste_fields - } - - browser_interface = images_history.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) - inspiration_interface = inspiration.ui(gr, opts, txt2img_prompt, img2img_prompt) - + with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): @@ -1651,8 +1640,6 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (inspiration_interface, "Inspiration", "inspiration"), - (browser_interface , "Image Browser", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), ] @@ -1896,6 +1883,7 @@ def load_javascript(raw_response): javascript = f'' scripts_list = modules.scripts.list_scripts("javascript", ".js") + scripts_list += modules.scripts.list_scripts("scripts", ".js") for basedir, filename, path in scripts_list: with open(path, "r", encoding="utf8") as jsfile: javascript += f"\n" -- cgit v1.2.3 From cef1b89aa2e6c7647db7e93a4cd4ec020da3f2da Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 24 Oct 2022 10:10:33 +0800 Subject: remove browser to extension --- modules/script_callbacks.py | 2 ++ modules/shared.py | 1 - modules/ui.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 66666a56..f46d3d9a 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,3 +1,4 @@ + callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] @@ -15,6 +16,7 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] + for callback in callbacks_ui_tabs: res += callback() or [] diff --git a/modules/shared.py b/modules/shared.py index 5dfd7927..6541e679 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -82,7 +82,6 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) -parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False) cmd_opts = parser.parse_args() restricted_opts = [ diff --git a/modules/ui.py b/modules/ui.py index fa42712e..a32f7259 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1104,7 +1104,7 @@ def create_ui(wrap_gradio_gpu_call): upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers] , value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") -- cgit v1.2.3 From a889c93f23f1e80d0dac4e5ddbc3a26207e8cdf1 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 24 Oct 2022 11:13:16 +0800 Subject: paste_fields add to public --- modules/ui.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index a32f7259..a73b9ff0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -784,6 +784,7 @@ def create_ui(wrap_gradio_gpu_call): ] ) + global txt2img_paste_fields txt2img_paste_fields = [ (txt2img_prompt, "Prompt"), (txt2img_negative_prompt, "Negative prompt"), @@ -1054,6 +1055,7 @@ def create_ui(wrap_gradio_gpu_call): outputs=[prompt, negative_prompt, style1, style2], ) + global img2img_paste_fields img2img_paste_fields = [ (img2img_prompt, "Prompt"), (img2img_negative_prompt, "Negative prompt"), -- cgit v1.2.3 From 974196932583b96b6b76632052fc0d7e70820bf3 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sun, 23 Oct 2022 22:38:42 +0300 Subject: Save properly processed image before color correction --- modules/processing.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index ff83023c..15b639e1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -46,6 +46,20 @@ def apply_color_correction(correction, image): return image +def apply_overlay(overlay_exists, overlay, paste_loc, image): + if overlay_exists: + if paste_loc is not None: + x, y, w, h = paste_loc + base_image = Image.new('RGBA', (overlay.width, overlay.height)) + image = images.resize_image(1, image, w, h) + base_image.paste(image, (x, y)) + image = base_image + + image = image.convert('RGBA') + image.alpha_composite(overlay) + image = image.convert('RGB') + + return image def get_correct_sampler(p): if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img): @@ -446,25 +460,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() image = Image.fromarray(x_sample) - + if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: - images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction") + image_without_cc = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image) + images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) - if p.overlay_images is not None and i < len(p.overlay_images): - overlay = p.overlay_images[i] - - if p.paste_to is not None: - x, y, w, h = p.paste_to - base_image = Image.new('RGBA', (overlay.width, overlay.height)) - image = images.resize_image(1, image, w, h) - base_image.paste(image, (x, y)) - image = base_image - - image = image.convert('RGBA') - image.alpha_composite(overlay) - image = image.convert('RGB') + image = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image) if opts.samples_save and not p.do_not_save_samples: images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) -- cgit v1.2.3 From f2cc3f32d5bc8538e95edec54d7dc1b9efdf769a Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sun, 23 Oct 2022 22:44:46 +0300 Subject: fix whitespaces --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 15b639e1..2a332514 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -460,7 +460,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() image = Image.fromarray(x_sample) - + if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: image_without_cc = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image) -- cgit v1.2.3 From b297cc3324979ec78d69b2d11dd18030dfad7bcc Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 20:06:42 +0900 Subject: Hypernetworks - fix KeyError in statistics caching Statistics logging has changed to {filename : list[losses]}, so it has to use loss_info[key].pop() --- modules/hypernetworks/hypernetwork.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 98a7b62e..33827210 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -274,8 +274,8 @@ def log_statistics(loss_info:dict, key, value): loss_info[key] = [value] else: loss_info[key].append(value) - if len(loss_info) > 1024: - loss_info.pop(0) + if len(loss_info[key]) > 1024: + loss_info[key].pop(0) def statistics(data): -- cgit v1.2.3 From 40b56c9289bf9458ae5ef3c1990ccea851c6c3e2 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:07:07 +0900 Subject: cleanup some code --- modules/hypernetworks/hypernetwork.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 33827210..4072bf54 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum +from collections import defaultdict, deque from statistics import stdev, mean class HypernetworkModule(torch.nn.Module): @@ -269,15 +270,6 @@ def stack_conds(conds): return torch.stack(conds) -def log_statistics(loss_info:dict, key, value): - if key not in loss_info: - loss_info[key] = [value] - else: - loss_info[key].append(value) - if len(loss_info[key]) > 1024: - loss_info[key].pop(0) - - def statistics(data): total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})" recent_data = data[-32:] @@ -341,7 +333,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log weight.requires_grad = True size = len(ds.indexes) - loss_dict = {} + loss_dict = defaultdict(lambda : deque(maxlen = 1024)) losses = torch.zeros((size,)) previous_mean_loss = 0 print("Mean loss of {} elements".format(size)) @@ -383,7 +375,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log losses[hypernetwork.step % losses.shape[0]] = loss.item() for entry in entries: - log_statistics(loss_dict, entry.filename, loss.item()) + loss_dict[entry.filename].append(loss.item()) optimizer.zero_grad() weights[0].grad = None -- cgit v1.2.3 From 348f89c8d40397c1875cff4a7331018785f9c3b8 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:29:53 +0900 Subject: statistics for pbar --- modules/hypernetworks/hypernetwork.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 4072bf54..48b56029 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -335,6 +335,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log size = len(ds.indexes) loss_dict = defaultdict(lambda : deque(maxlen = 1024)) losses = torch.zeros((size,)) + previous_mean_losses = [0] previous_mean_loss = 0 print("Mean loss of {} elements".format(size)) @@ -356,7 +357,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log for i, entries in pbar: hypernetwork.step = i + ititial_step if len(loss_dict) > 0: - previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict) + previous_mean_losses = [i[-1] for i in loss_dict.values()] + previous_mean_loss = mean(previous_mean_losses) scheduler.apply(optimizer, hypernetwork.step) if scheduler.finished: @@ -391,7 +393,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): raise RuntimeError("Loss diverged.") - pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}") + + if len(previous_mean_losses) > 1: + std = stdev(previous_mean_losses) + else: + std = 0 + dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" + pbar.set_description(dataset_loss_info) if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. -- cgit v1.2.3 From 0d2e1dac407a0e2f5b148d314715f0457b2525b7 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:41:39 +0900 Subject: convert deque -> list I don't feel this being efficient --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 48b56029..fb510fa7 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -282,7 +282,7 @@ def report_statistics(loss_info:dict): for key in keys: try: print("Loss statistics for file " + key) - info, recent = statistics(loss_info[key]) + info, recent = statistics(list(loss_info[key])) print(info) print(recent) except Exception as e: -- cgit v1.2.3 From e9a410b5357612f63528015c5533c2185dcff92e Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:47:39 +0900 Subject: check length for variance --- modules/hypernetworks/hypernetwork.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index fb510fa7..d647ea55 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -271,9 +271,17 @@ def stack_conds(conds): def statistics(data): - total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})" + if len(data) < 2: + std = 0 + else: + std = stdev(data) + total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})" recent_data = data[-32:] - recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})" + if len(recent_data) < 2: + std = 0 + else: + std = stdev(recent_data) + recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})" return total_information, recent_information -- cgit v1.2.3 From 6cbb04f7a5e675cf1f6dfc247aa9c9e8df7dc5ce Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 24 Oct 2022 09:15:26 +0300 Subject: fix #3517 breaking txt2img --- modules/processing.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 2a332514..c61bbfbd 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -46,18 +46,23 @@ def apply_color_correction(correction, image): return image -def apply_overlay(overlay_exists, overlay, paste_loc, image): - if overlay_exists: - if paste_loc is not None: - x, y, w, h = paste_loc - base_image = Image.new('RGBA', (overlay.width, overlay.height)) - image = images.resize_image(1, image, w, h) - base_image.paste(image, (x, y)) - image = base_image - - image = image.convert('RGBA') - image.alpha_composite(overlay) - image = image.convert('RGB') + +def apply_overlay(image, paste_loc, index, overlays): + if overlays is None or index >= len(overlays): + return image + + overlay = overlays[index] + + if paste_loc is not None: + x, y, w, h = paste_loc + base_image = Image.new('RGBA', (overlay.width, overlay.height)) + image = images.resize_image(1, image, w, h) + base_image.paste(image, (x, y)) + image = base_image + + image = image.convert('RGBA') + image.alpha_composite(overlay) + image = image.convert('RGB') return image @@ -463,11 +468,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: - image_without_cc = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image) + image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) - image = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image) + image = apply_overlay(image, p.paste_to, i, p.overlay_images) if opts.samples_save and not p.do_not_save_samples: images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) -- cgit v1.2.3 From 734986dde3231416813f827242c111da212b2ccb Mon Sep 17 00:00:00 2001 From: Trung Ngo Date: Mon, 24 Oct 2022 01:17:09 -0500 Subject: add callback after image is saved --- modules/images.py | 3 ++- modules/script_callbacks.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index b9589563..01c60f89 100644 --- a/modules/images.py +++ b/modules/images.py @@ -12,7 +12,7 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin from fonts.ttf import Roboto import string -from modules import sd_samplers, shared +from modules import sd_samplers, shared, script_callbacks from modules.shared import opts, cmd_opts LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) @@ -467,6 +467,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i else: txt_fullfn = None + script_callbacks.image_saved_callback(image, p, fullfn, txt_fullfn) return fullfn, txt_fullfn diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 5bcccd67..5836e4b9 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,11 +2,12 @@ callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] - +callbacks_image_saved = [] def clear_callbacks(): callbacks_model_loaded.clear() callbacks_ui_tabs.clear() + callbacks_image_saved.clear() def model_loaded_callback(sd_model): @@ -28,6 +29,10 @@ def ui_settings_callback(): callback() +def image_saved_callback(image, p, fullfn, txt_fullfn): + for callback in callbacks_image_saved: + callback(image, p, fullfn, txt_fullfn) + def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument""" @@ -51,3 +56,8 @@ def on_ui_settings(callback): """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ callbacks_ui_settings.append(callback) + + +def on_save_imaged(callback): + """register a function to call after modules.images.save_image is called returning same values, original image and p """ + callbacks_image_saved.append(callback) -- cgit v1.2.3 From 876a96f0f9843382ebc8984db3de5d8af0e9ce4c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 24 Oct 2022 09:39:46 +0300 Subject: remove erroneous dir in the extension directory remove loading .js files from scripts dir (they go into javascript) load scripts after models, for scripts that depend on loaded models --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index a73b9ff0..03528968 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1885,7 +1885,7 @@ def load_javascript(raw_response): javascript = f'' scripts_list = modules.scripts.list_scripts("javascript", ".js") - scripts_list += modules.scripts.list_scripts("scripts", ".js") + for basedir, filename, path in scripts_list: with open(path, "r", encoding="utf8") as jsfile: javascript += f"\n" -- cgit v1.2.3 From 3be6b29d81408d2adb741bff5b11c80214aa621e Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 24 Oct 2022 15:14:34 +0900 Subject: indent=4 config.json --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 6541e679..d6ddfe59 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -348,7 +348,7 @@ class Options: def save(self, filename): with open(filename, "w", encoding="utf8") as file: - json.dump(self.data, file) + json.dump(self.data, file, indent=4) def same_type(self, x, y): if x is None or y is None: -- cgit v1.2.3 From c5d90628a4058bf49c2fdabf620a24db73407f31 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 22 Oct 2022 17:16:55 +0900 Subject: move "file_decoration" initialize section into "if forced_filename is None:" no need to initialize it if it's not going to be used --- modules/images.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index b9589563..50a59cff 100644 --- a/modules/images.py +++ b/modules/images.py @@ -386,18 +386,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i txt_fullfn (`str` or None): If a text file is saved for this image, this will be its full path. Otherwise None. ''' - if short_filename or prompt is None or seed is None: - file_decoration = "" - elif opts.save_to_dirs: - file_decoration = opts.samples_filename_pattern or "[seed]" - else: - file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]" - - if file_decoration != "": - file_decoration = "-" + file_decoration.lower() - - file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix - if extension == 'png' and opts.enable_pnginfo and info is not None: pnginfo = PngImagePlugin.PngInfo() @@ -419,6 +407,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i os.makedirs(path, exist_ok=True) if forced_filename is None: + if short_filename or prompt is None or seed is None: + file_decoration = "" + elif opts.save_to_dirs: + file_decoration = opts.samples_filename_pattern or "[seed]" + else: + file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]" + + if file_decoration != "": + file_decoration = "-" + file_decoration.lower() + + file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix + basecount = get_next_sequence_number(path, basename) fullfn = "a.png" fullfn_without_extension = "a" -- cgit v1.2.3 From 7d4a4db9ea7543c079f4a4a702c2945f4b66cd11 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 22 Oct 2022 17:48:59 +0900 Subject: modify unnecessary sting assignment as it's going to get overwritten --- modules/images.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 50a59cff..cc5066b1 100644 --- a/modules/images.py +++ b/modules/images.py @@ -420,8 +420,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix basecount = get_next_sequence_number(path, basename) - fullfn = "a.png" - fullfn_without_extension = "a" + fullfn = None + fullfn_without_extension = None for i in range(500): fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}" fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}") -- cgit v1.2.3 From 37dd6deafb831a809eaf7ae8d232937a8c7998e7 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 22 Oct 2022 21:11:15 +0900 Subject: filename pattern [datetime], extended customizable Format and Time Zone format: [datetime] [datetime] [datetime