From ad4de819c43997f2666b5bad95301f5c37f9018e Mon Sep 17 00:00:00 2001 From: victorca25 Date: Sun, 9 Oct 2022 13:02:12 +0200 Subject: update ESRGAN architecture and model to support all ESRGAN models in the DB, BSRGAN and real-ESRGAN models --- modules/bsrgan_model.py | 76 ------- modules/bsrgan_model_arch.py | 102 ---------- modules/esrgam_model_arch.py | 80 -------- modules/esrgan_model.py | 190 ++++++++++++------ modules/esrgan_model_arch.py | 463 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 591 insertions(+), 320 deletions(-) delete mode 100644 modules/bsrgan_model.py delete mode 100644 modules/bsrgan_model_arch.py delete mode 100644 modules/esrgam_model_arch.py create mode 100644 modules/esrgan_model_arch.py (limited to 'modules') diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py deleted file mode 100644 index 737e1a76..00000000 --- a/modules/bsrgan_model.py +++ /dev/null @@ -1,76 +0,0 @@ -import os.path -import sys -import traceback - -import PIL.Image -import numpy as np -import torch -from basicsr.utils.download_util import load_file_from_url - -import modules.upscaler -from modules import devices, modelloader -from modules.bsrgan_model_arch import RRDBNet - - -class UpscalerBSRGAN(modules.upscaler.Upscaler): - def __init__(self, dirname): - self.name = "BSRGAN" - self.model_name = "BSRGAN 4x" - self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth" - self.user_path = dirname - super().__init__() - model_paths = self.find_models(ext_filter=[".pt", ".pth"]) - scalers = [] - if len(model_paths) == 0: - scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4) - scalers.append(scaler_data) - for file in model_paths: - if "http" in file: - name = self.model_name - else: - name = modelloader.friendly_name(file) - try: - scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) - scalers.append(scaler_data) - except Exception: - print(f"Error loading BSRGAN model: {file}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - self.scalers = scalers - - def do_upscale(self, img: PIL.Image, selected_file): - torch.cuda.empty_cache() - model = self.load_model(selected_file) - if model is None: - return img - model.to(devices.device_bsrgan) - torch.cuda.empty_cache() - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_bsrgan) - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - torch.cuda.empty_cache() - return PIL.Image.fromarray(output, 'RGB') - - def load_model(self, path: str): - if "http" in path: - filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, - progress=True) - else: - filename = path - if not os.path.exists(filename) or filename is None: - print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr) - return None - model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network - model.load_state_dict(torch.load(filename), strict=True) - model.eval() - for k, v in model.named_parameters(): - v.requires_grad = False - return model - diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py deleted file mode 100644 index cb4d1c13..00000000 --- a/modules/bsrgan_model_arch.py +++ /dev/null @@ -1,102 +0,0 @@ -import functools -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.nn.init as init - - -def initialize_weights(net_l, scale=1): - if not isinstance(net_l, list): - net_l = [net_l] - for net in net_l: - for m in net.modules(): - if isinstance(m, nn.Conv2d): - init.kaiming_normal_(m.weight, a=0, mode='fan_in') - m.weight.data *= scale # for residual block - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - init.kaiming_normal_(m.weight, a=0, mode='fan_in') - m.weight.data *= scale - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - init.constant_(m.weight, 1) - init.constant_(m.bias.data, 0.0) - - -def make_layer(block, n_layers): - layers = [] - for _ in range(n_layers): - layers.append(block()) - return nn.Sequential(*layers) - - -class ResidualDenseBlock_5C(nn.Module): - def __init__(self, nf=64, gc=32, bias=True): - super(ResidualDenseBlock_5C, self).__init__() - # gc: growth channel, i.e. intermediate channels - self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) - self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) - self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) - self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) - self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - # initialization - initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) - - def forward(self, x): - x1 = self.lrelu(self.conv1(x)) - x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) - x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) - x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5 * 0.2 + x - - -class RRDB(nn.Module): - '''Residual in Residual Dense Block''' - - def __init__(self, nf, gc=32): - super(RRDB, self).__init__() - self.RDB1 = ResidualDenseBlock_5C(nf, gc) - self.RDB2 = ResidualDenseBlock_5C(nf, gc) - self.RDB3 = ResidualDenseBlock_5C(nf, gc) - - def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - return out * 0.2 + x - - -class RRDBNet(nn.Module): - def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4): - super(RRDBNet, self).__init__() - RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) - self.sf = sf - - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.RRDB_trunk = make_layer(RRDB_block_f, nb) - self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - #### upsampling - self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - if self.sf==4: - self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - def forward(self, x): - fea = self.conv_first(x) - trunk = self.trunk_conv(self.RRDB_trunk(fea)) - fea = fea + trunk - - fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) - if self.sf==4: - fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) - out = self.conv_last(self.lrelu(self.HRconv(fea))) - - return out \ No newline at end of file diff --git a/modules/esrgam_model_arch.py b/modules/esrgam_model_arch.py deleted file mode 100644 index e413d36e..00000000 --- a/modules/esrgam_model_arch.py +++ /dev/null @@ -1,80 +0,0 @@ -# this file is taken from https://github.com/xinntao/ESRGAN - -import functools -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def make_layer(block, n_layers): - layers = [] - for _ in range(n_layers): - layers.append(block()) - return nn.Sequential(*layers) - - -class ResidualDenseBlock_5C(nn.Module): - def __init__(self, nf=64, gc=32, bias=True): - super(ResidualDenseBlock_5C, self).__init__() - # gc: growth channel, i.e. intermediate channels - self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) - self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) - self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) - self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) - self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - # initialization - # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) - - def forward(self, x): - x1 = self.lrelu(self.conv1(x)) - x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) - x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) - x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5 * 0.2 + x - - -class RRDB(nn.Module): - '''Residual in Residual Dense Block''' - - def __init__(self, nf, gc=32): - super(RRDB, self).__init__() - self.RDB1 = ResidualDenseBlock_5C(nf, gc) - self.RDB2 = ResidualDenseBlock_5C(nf, gc) - self.RDB3 = ResidualDenseBlock_5C(nf, gc) - - def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - return out * 0.2 + x - - -class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32): - super(RRDBNet, self).__init__() - RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) - - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.RRDB_trunk = make_layer(RRDB_block_f, nb) - self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - #### upsampling - self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - def forward(self, x): - fea = self.conv_first(x) - trunk = self.trunk_conv(self.RRDB_trunk(fea)) - fea = fea + trunk - - fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) - fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) - out = self.conv_last(self.lrelu(self.HRconv(fea))) - - return out diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 3970e6e4..a49e2258 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -5,68 +5,115 @@ import torch from PIL import Image from basicsr.utils.download_util import load_file_from_url -import modules.esrgam_model_arch as arch +import modules.esrgan_model_arch as arch from modules import shared, modelloader, images, devices from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts -def fix_model_layers(crt_model, pretrained_net): - # this code is adapted from https://github.com/xinntao/ESRGAN - if 'conv_first.weight' in pretrained_net: - return pretrained_net - if 'model.0.weight' not in pretrained_net: - is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"] - if is_realesrgan: - raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.") - else: - raise Exception("The file is not a ESRGAN model.") +def mod2normal(state_dict): + # this code is copied from https://github.com/victorca25/iNNfer + if 'conv_first.weight' in state_dict: + crt_net = {} + items = [] + for k, v in state_dict.items(): + items.append(k) + + crt_net['model.0.weight'] = state_dict['conv_first.weight'] + crt_net['model.0.bias'] = state_dict['conv_first.bias'] + + for k in items.copy(): + if 'RDB' in k: + ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') + if '.weight' in k: + ori_k = ori_k.replace('.weight', '.0.weight') + elif '.bias' in k: + ori_k = ori_k.replace('.bias', '.0.bias') + crt_net[ori_k] = state_dict[k] + items.remove(k) + + crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight'] + crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias'] + crt_net['model.3.weight'] = state_dict['upconv1.weight'] + crt_net['model.3.bias'] = state_dict['upconv1.bias'] + crt_net['model.6.weight'] = state_dict['upconv2.weight'] + crt_net['model.6.bias'] = state_dict['upconv2.bias'] + crt_net['model.8.weight'] = state_dict['HRconv.weight'] + crt_net['model.8.bias'] = state_dict['HRconv.bias'] + crt_net['model.10.weight'] = state_dict['conv_last.weight'] + crt_net['model.10.bias'] = state_dict['conv_last.bias'] + state_dict = crt_net + return state_dict + + +def resrgan2normal(state_dict, nb=23): + # this code is copied from https://github.com/victorca25/iNNfer + if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: + crt_net = {} + items = [] + for k, v in state_dict.items(): + items.append(k) + + crt_net['model.0.weight'] = state_dict['conv_first.weight'] + crt_net['model.0.bias'] = state_dict['conv_first.bias'] + + for k in items.copy(): + if "rdb" in k: + ori_k = k.replace('body.', 'model.1.sub.') + ori_k = ori_k.replace('.rdb', '.RDB') + if '.weight' in k: + ori_k = ori_k.replace('.weight', '.0.weight') + elif '.bias' in k: + ori_k = ori_k.replace('.bias', '.0.bias') + crt_net[ori_k] = state_dict[k] + items.remove(k) + + crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight'] + crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias'] + crt_net['model.3.weight'] = state_dict['conv_up1.weight'] + crt_net['model.3.bias'] = state_dict['conv_up1.bias'] + crt_net['model.6.weight'] = state_dict['conv_up2.weight'] + crt_net['model.6.bias'] = state_dict['conv_up2.bias'] + crt_net['model.8.weight'] = state_dict['conv_hr.weight'] + crt_net['model.8.bias'] = state_dict['conv_hr.bias'] + crt_net['model.10.weight'] = state_dict['conv_last.weight'] + crt_net['model.10.bias'] = state_dict['conv_last.bias'] + state_dict = crt_net + return state_dict + + +def infer_params(state_dict): + # this code is copied from https://github.com/victorca25/iNNfer + scale2x = 0 + scalemin = 6 + n_uplayer = 0 + plus = False + + for block in list(state_dict): + parts = block.split(".") + n_parts = len(parts) + if n_parts == 5 and parts[2] == "sub": + nb = int(parts[3]) + elif n_parts == 3: + part_num = int(parts[1]) + if (part_num > scalemin + and parts[0] == "model" + and parts[2] == "weight"): + scale2x += 1 + if part_num > n_uplayer: + n_uplayer = part_num + out_nc = state_dict[block].shape[0] + if not plus and "conv1x1" in block: + plus = True + + nf = state_dict["model.0.weight"].shape[0] + in_nc = state_dict["model.0.weight"].shape[1] + out_nc = out_nc + scale = 2 ** scale2x + + return in_nc, out_nc, nf, nb, plus, scale - crt_net = crt_model.state_dict() - load_net_clean = {} - for k, v in pretrained_net.items(): - if k.startswith('module.'): - load_net_clean[k[7:]] = v - else: - load_net_clean[k] = v - pretrained_net = load_net_clean - - tbd = [] - for k, v in crt_net.items(): - tbd.append(k) - - # directly copy - for k, v in crt_net.items(): - if k in pretrained_net and pretrained_net[k].size() == v.size(): - crt_net[k] = pretrained_net[k] - tbd.remove(k) - - crt_net['conv_first.weight'] = pretrained_net['model.0.weight'] - crt_net['conv_first.bias'] = pretrained_net['model.0.bias'] - - for k in tbd.copy(): - if 'RDB' in k: - ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[k] = pretrained_net[ori_k] - tbd.remove(k) - - crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight'] - crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias'] - crt_net['upconv1.weight'] = pretrained_net['model.3.weight'] - crt_net['upconv1.bias'] = pretrained_net['model.3.bias'] - crt_net['upconv2.weight'] = pretrained_net['model.6.weight'] - crt_net['upconv2.bias'] = pretrained_net['model.6.bias'] - crt_net['HRconv.weight'] = pretrained_net['model.8.weight'] - crt_net['HRconv.bias'] = pretrained_net['model.8.bias'] - crt_net['conv_last.weight'] = pretrained_net['model.10.weight'] - crt_net['conv_last.bias'] = pretrained_net['model.10.bias'] - - return crt_net class UpscalerESRGAN(Upscaler): def __init__(self, dirname): @@ -109,20 +156,39 @@ class UpscalerESRGAN(Upscaler): print("Unable to load %s from %s" % (self.model_path, filename)) return None - pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) - crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) + state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) + + if "params_ema" in state_dict: + state_dict = state_dict["params_ema"] + elif "params" in state_dict: + state_dict = state_dict["params"] + num_conv = 16 if "realesr-animevideov3" in filename else 32 + model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu') + model.load_state_dict(state_dict) + model.eval() + return model + + if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict: + nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23 + state_dict = resrgan2normal(state_dict, nb) + elif "conv_first.weight" in state_dict: + state_dict = mod2normal(state_dict) + elif "model.0.weight" not in state_dict: + raise Exception("The file is not a recognized ESRGAN model.") + + in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict) - pretrained_net = fix_model_layers(crt_model, pretrained_net) - crt_model.load_state_dict(pretrained_net) - crt_model.eval() + model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus) + model.load_state_dict(state_dict) + model.eval() - return crt_model + return model def upscale_without_tiling(model, img): img = np.array(img) img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 + img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() img = img.unsqueeze(0).to(devices.device_esrgan) with torch.no_grad(): diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py new file mode 100644 index 00000000..bc9ceb2a --- /dev/null +++ b/modules/esrgan_model_arch.py @@ -0,0 +1,463 @@ +# this file is adapted from https://github.com/victorca25/iNNfer + +import math +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F + + +#################### +# RRDBNet Generator +#################### + +class RRDBNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, + act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', + finalact=None, gaussian_noise=False, plus=False): + super(RRDBNet, self).__init__() + n_upscale = int(math.log(upscale, 2)) + if upscale == 3: + n_upscale = 1 + + self.resrgan_scale = 0 + if in_nc % 16 == 0: + self.resrgan_scale = 1 + elif in_nc != 4 and in_nc % 4 == 0: + self.resrgan_scale = 2 + + fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) + rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', + norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, + gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)] + LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype) + + if upsample_mode == 'upconv': + upsample_block = upconv_block + elif upsample_mode == 'pixelshuffle': + upsample_block = pixelshuffle_block + else: + raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) + if upscale == 3: + upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype) + else: + upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)] + HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype) + HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) + + outact = act(finalact) if finalact else None + + self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)), + *upsampler, HR_conv0, HR_conv1, outact) + + def forward(self, x, outm=None): + if self.resrgan_scale == 1: + feat = pixel_unshuffle(x, scale=4) + elif self.resrgan_scale == 2: + feat = pixel_unshuffle(x, scale=2) + else: + feat = x + + return self.model(feat) + + +class RRDB(nn.Module): + """ + Residual in Residual Dense Block + (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) + """ + + def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', + norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', + spectral_norm=False, gaussian_noise=False, plus=False): + super(RRDB, self).__init__() + # This is for backwards compatibility with existing models + if nr == 3: + self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, + gaussian_noise=gaussian_noise, plus=plus) + self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, + gaussian_noise=gaussian_noise, plus=plus) + self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, + gaussian_noise=gaussian_noise, plus=plus) + else: + RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, + gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)] + self.RDBs = nn.Sequential(*RDB_list) + + def forward(self, x): + if hasattr(self, 'RDB1'): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + else: + out = self.RDBs(x) + return out * 0.2 + x + + +class ResidualDenseBlock_5C(nn.Module): + """ + Residual Dense Block + The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) + Modified options that can be used: + - "Partial Convolution based Padding" arXiv:1811.11718 + - "Spectral normalization" arXiv:1802.05957 + - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. + {Rakotonirina} and A. {Rasoanaivo} + """ + + def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', + norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', + spectral_norm=False, gaussian_noise=False, plus=False): + super(ResidualDenseBlock_5C, self).__init__() + + self.noise = GaussianNoise() if gaussian_noise else None + self.conv1x1 = conv1x1(nf, gc) if plus else None + + self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + if mode == 'CNA': + last_act = None + else: + last_act = act_type + self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(torch.cat((x, x1), 1)) + if self.conv1x1: + x2 = x2 + self.conv1x1(x) + x3 = self.conv3(torch.cat((x, x1, x2), 1)) + x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) + if self.conv1x1: + x4 = x4 + x2 + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + if self.noise: + return self.noise(x5.mul(0.2) + x) + else: + return x5 * 0.2 + x + + +#################### +# ESRGANplus +#################### + +class GaussianNoise(nn.Module): + def __init__(self, sigma=0.1, is_relative_detach=False): + super().__init__() + self.sigma = sigma + self.is_relative_detach = is_relative_detach + self.noise = torch.tensor(0, dtype=torch.float) + + def forward(self, x): + if self.training and self.sigma != 0: + self.noise = self.noise.to(x.device) + scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x + sampled_noise = self.noise.repeat(*x.size()).normal_() * scale + x = x + sampled_noise + return x + +def conv1x1(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +#################### +# SRVGGNetCompact +#################### + +class SRVGGNetCompact(nn.Module): + """A compact VGG-style network structure for super-resolution. + This class is copied from https://github.com/xinntao/Real-ESRGAN + """ + + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): + super(SRVGGNetCompact, self).__init__() + self.num_in_ch = num_in_ch + self.num_out_ch = num_out_ch + self.num_feat = num_feat + self.num_conv = num_conv + self.upscale = upscale + self.act_type = act_type + + self.body = nn.ModuleList() + # the first conv + self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) + # the first activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the body structure + for _ in range(num_conv): + self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) + # activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the last conv + self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) + # upsample + self.upsampler = nn.PixelShuffle(upscale) + + def forward(self, x): + out = x + for i in range(0, len(self.body)): + out = self.body[i](out) + + out = self.upsampler(out) + # add the nearest upsampled image, so that the network learns the residual + base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') + out += base + return out + + +#################### +# Upsampler +#################### + +class Upsample(nn.Module): + r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. + The input data is assumed to be of the form + `minibatch x channels x [optional depth] x [optional height] x width`. + """ + + def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None): + super(Upsample, self).__init__() + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.size = size + self.align_corners = align_corners + + def forward(self, x): + return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + + def extra_repr(self): + if self.scale_factor is not None: + info = 'scale_factor=' + str(self.scale_factor) + else: + info = 'size=' + str(self.size) + info += ', mode=' + self.mode + return info + + +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, + pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'): + """ + Pixel shuffle layer + (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional + Neural Network, CVPR17) + """ + conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, + pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + + n = norm(norm_type, out_nc) if norm_type else None + a = act(act_type) if act_type else None + return sequential(conv, pixel_shuffle, n, a) + + +def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, + pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'): + """ Upconv layer """ + upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor + upsample = Upsample(scale_factor=upscale_factor, mode=mode) + conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, + pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype) + return sequential(upsample, conv) + + + + + + + + +#################### +# Basic blocks +#################### + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + Args: + basic_block (nn.module): nn.module class for basic block. (block) + num_basic_block (int): number of blocks. (n_layers) + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0): + """ activation helper """ + act_type = act_type.lower() + if act_type == 'relu': + layer = nn.ReLU(inplace) + elif act_type in ('leakyrelu', 'lrelu'): + layer = nn.LeakyReLU(neg_slope, inplace) + elif act_type == 'prelu': + layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) + elif act_type == 'tanh': # [-1, 1] range output + layer = nn.Tanh() + elif act_type == 'sigmoid': # [0, 1] range output + layer = nn.Sigmoid() + else: + raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) + return layer + + +class Identity(nn.Module): + def __init__(self, *kwargs): + super(Identity, self).__init__() + + def forward(self, x, *kwargs): + return x + + +def norm(norm_type, nc): + """ Return a normalization layer """ + norm_type = norm_type.lower() + if norm_type == 'batch': + layer = nn.BatchNorm2d(nc, affine=True) + elif norm_type == 'instance': + layer = nn.InstanceNorm2d(nc, affine=False) + elif norm_type == 'none': + def norm_layer(x): return Identity() + else: + raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type)) + return layer + + +def pad(pad_type, padding): + """ padding layer helper """ + pad_type = pad_type.lower() + if padding == 0: + return None + if pad_type == 'reflect': + layer = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + layer = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + layer = nn.ZeroPad2d(padding) + else: + raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type)) + return layer + + +def get_valid_padding(kernel_size, dilation): + kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) + padding = (kernel_size - 1) // 2 + return padding + + +class ShortcutBlock(nn.Module): + """ Elementwise sum the output of a submodule to its input """ + def __init__(self, submodule): + super(ShortcutBlock, self).__init__() + self.sub = submodule + + def forward(self, x): + output = x + self.sub(x) + return output + + def __repr__(self): + return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|') + + +def sequential(*args): + """ Flatten Sequential. It unwraps nn.Sequential. """ + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError('sequential does not support OrderedDict input.') + return args[0] # No sequential is needed. + modules = [] + for module in args: + if isinstance(module, nn.Sequential): + for submodule in module.children(): + modules.append(submodule) + elif isinstance(module, nn.Module): + modules.append(module) + return nn.Sequential(*modules) + + +def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, + pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', + spectral_norm=False): + """ Conv layer with padding, normalization, activation """ + assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode) + padding = get_valid_padding(kernel_size, dilation) + p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None + padding = padding if pad_type == 'zero' else 0 + + if convtype=='PartialConv2D': + c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, bias=bias, groups=groups) + elif convtype=='DeformConv2D': + c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, bias=bias, groups=groups) + elif convtype=='Conv3D': + c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, bias=bias, groups=groups) + else: + c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, bias=bias, groups=groups) + + if spectral_norm: + c = nn.utils.spectral_norm(c) + + a = act(act_type) if act_type else None + if 'CNA' in mode: + n = norm(norm_type, out_nc) if norm_type else None + return sequential(p, c, n, a) + elif mode == 'NAC': + if norm_type is None and act_type is not None: + a = act(act_type, inplace=False) + n = norm(norm_type, in_nc) if norm_type else None + return sequential(n, a, p, c) -- cgit v1.2.3 From bb57f30c2de46cfca5419ad01738a41705f96cc3 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Fri, 14 Oct 2022 10:56:41 +0200 Subject: init --- README.md | 1 + aesthetic_embeddings/insert_embs_here.txt | 0 modules/processing.py | 17 +++++- modules/sd_hijack.py | 80 +++++++++++++++++++++++++- modules/shared.py | 5 ++ modules/textual_inversion/dataset.py | 2 +- modules/textual_inversion/textual_inversion.py | 35 +++++++---- modules/txt2img.py | 11 +++- modules/ui.py | 59 ++++++++++++------- 9 files changed, 172 insertions(+), 38 deletions(-) create mode 100644 aesthetic_embeddings/insert_embs_here.txt (limited to 'modules') diff --git a/README.md b/README.md index 859a91b6..7b8d018b 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - No token limit for prompts (original stable diffusion lets you use up to 75 tokens) - DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args) - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args) +- Aesthetic, a way to generate images with a specific aesthetic by using clip images embds (implementation of https://github.com/vicgalle/stable-diffusion-aesthetic-gradients) ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. diff --git a/aesthetic_embeddings/insert_embs_here.txt b/aesthetic_embeddings/insert_embs_here.txt new file mode 100644 index 00000000..e69de29b diff --git a/modules/processing.py b/modules/processing.py index d5172f00..9a033759 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -316,11 +316,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() -def process_images(p: StableDiffusionProcessing) -> Processed: +def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, + aesthetic_imgs=None,aesthetic_slerp=False) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" + aesthetic_lr = float(aesthetic_lr) + aesthetic_weight = float(aesthetic_weight) + aesthetic_steps = int(aesthetic_steps) + if type(p.prompt) == list: - assert(len(p.prompt) > 0) + assert (len(p.prompt) > 0) else: assert p.prompt is not None @@ -394,7 +399,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): - uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) + if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): + shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0) + uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], + p.steps) + if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): + shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight, + aesthetic_steps, aesthetic_imgs,aesthetic_slerp) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c81722a0..6d5196fe 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -9,11 +9,14 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared -from modules.shared import opts, device, cmd_opts +from modules.shared import opts, device, cmd_opts, aesthetic_embeddings from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model +from transformers import CLIPVisionModel, CLIPModel +import torch.optim as optim +import copy attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity @@ -109,13 +112,29 @@ class StableDiffusionModelHijack: _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) +def slerp(low, high, val): + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm*high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped + self.clipModel = CLIPModel.from_pretrained( + self.wrapped.transformer.name_or_path + ) + del self.clipModel.vision_model self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer + # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval() + self.image_embs_name = None + self.image_embs = None + self.load_image_embs(None) + self.token_mults = {} self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] @@ -136,6 +155,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult + def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None, + aesthetic_slerp=True): + self.slerp = aesthetic_slerp + self.aesthetic_lr = aesthetic_lr + self.aesthetic_weight = aesthetic_weight + self.aesthetic_steps = aesthetic_steps + self.load_image_embs(image_embs_name) + + def load_image_embs(self, image_embs_name): + if image_embs_name is None or len(image_embs_name) == 0: + image_embs_name = None + if image_embs_name is not None and self.image_embs_name != image_embs_name: + self.image_embs_name = image_embs_name + self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device) + self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) + self.image_embs.requires_grad_(False) + def tokenize_line(self, line, used_custom_terms, hijack_comments): id_end = self.wrapped.tokenizer.eos_token_id @@ -333,7 +369,47 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - + + if len(text[ + 0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [ + [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in + remade_batch_tokens] + + tokens = torch.asarray(remade_batch_tokens).to(device) + with torch.enable_grad(): + model = copy.deepcopy(self.clipModel).to(device) + model.requires_grad_(True) + + # We optimize the model to maximize the similarity + optimizer = optim.Adam( + model.text_model.parameters(), lr=self.aesthetic_lr + ) + + for i in range(self.aesthetic_steps): + text_embs = model.get_text_features(input_ids=tokens) + text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) + sim = text_embs @ self.image_embs.T + loss = -sim + optimizer.zero_grad() + loss.mean().backward() + optimizer.step() + + zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + if opts.CLIP_stop_at_last_layers > 1: + zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] + zn = model.text_model.final_layer_norm(zn) + else: + zn = zn.last_hidden_state + model.cpu() + del model + + if self.slerp: + z = slerp(z, zn, self.aesthetic_weight) + else: + z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight + remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 diff --git a/modules/shared.py b/modules/shared.py index 5901e605..cf13a10d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -30,6 +30,8 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(script_path, 'aesthetic_embeddings'), + help="aesthetic_embeddings directory(default: aesthetic_embeddings)") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") @@ -90,6 +92,9 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None +aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in + os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} + def reload_hypernetworks(): global hypernetworks diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 67e90afe..59b2b021 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -48,7 +48,7 @@ class PersonalizedBase(Dataset): print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) + image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC) except Exception: continue diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index fa0e33a2..b12a8e6d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -172,7 +172,15 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): +def batched(dataset, total, n=1): + for ndx in range(0, total, n): + yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] + + +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, + create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, + preview_image_prompt, batch_size=1, + gradient_accumulation=1): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -204,7 +212,11 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, + height=training_height, + repeats=shared.opts.training_image_repeats_per_epoch, + placeholder_token=embedding_name, model=shared.sd_model, + device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -223,7 +235,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + pbar = tqdm.tqdm(enumerate(batched(ds, steps - ititial_step, batch_size)), total=steps - ititial_step) for i, entry in pbar: embedding.step = i + ititial_step @@ -235,17 +247,20 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini break with torch.autocast("cuda"): - c = cond_model([entry.cond_text]) + c = cond_model([e.cond_text for e in entry]) + + x = torch.stack([e.latent for e in entry]).to(devices.device) + loss = shared.sd_model(x, c)[0] - x = entry.latent.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] del x losses[embedding.step % losses.shape[0]] = loss.item() - optimizer.zero_grad() loss.backward() - optimizer.step() + if ((i + 1) % gradient_accumulation == 0) or (i + 1 == steps - ititial_step): + optimizer.step() + optimizer.zero_grad() + epoch_num = embedding.step // len(ds) epoch_step = embedding.step - (epoch_num * len(ds)) + 1 @@ -259,7 +274,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') - preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt + preview_text = entry[0].cond_text if preview_image_prompt == "" else preview_image_prompt p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, @@ -305,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini

Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entry[-1].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

diff --git a/modules/txt2img.py b/modules/txt2img.py index e985242b..78342024 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -6,7 +6,14 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, + restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, + subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, + height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, + aesthetic_lr=0, + aesthetic_weight=0, aesthetic_steps=0, + aesthetic_imgs=None, + aesthetic_slerp=False, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -40,7 +47,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is None: - processed = process_images(p) + processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp) shared.total_tqdm.clear() diff --git a/modules/ui.py b/modules/ui.py index 220fb80b..d961d126 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -24,7 +24,8 @@ import gradio.routes from modules import sd_hijack from modules.paths import script_path -from modules.shared import opts, cmd_opts +from modules.shared import opts, cmd_opts,aesthetic_embeddings + if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags import modules.shared as shared @@ -534,6 +535,14 @@ def create_ui(wrap_gradio_gpu_call): width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + with gr.Group(): + aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7) + aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50) + + aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) + aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) + with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) tiling = gr.Checkbox(label='Tiling', value=False) @@ -586,25 +595,30 @@ def create_ui(wrap_gradio_gpu_call): fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), _js="submit", inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - scale_latent, - denoising_strength, - ] + custom_inputs, + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_style, + txt2img_prompt_style2, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + scale_latent, + denoising_strength, + aesthetic_lr, + aesthetic_weight, + aesthetic_steps, + aesthetic_imgs, + aesthetic_slerp + ] + custom_inputs, outputs=[ txt2img_gallery, generation_info, @@ -1097,6 +1111,9 @@ def create_ui(wrap_gradio_gpu_call): template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + batch_size = gr.Slider(minimum=1, maximum=64, step=1, label="Batch Size", value=4) + gradient_accumulation = gr.Slider(minimum=1, maximum=256, step=1, label="Gradient accumulation", + value=1) steps = gr.Number(label='Max steps', value=100000, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) @@ -1180,6 +1197,8 @@ def create_ui(wrap_gradio_gpu_call): template_file, save_image_with_stored_embedding, preview_image_prompt, + batch_size, + gradient_accumulation ], outputs=[ ti_output, -- cgit v1.2.3 From 37d7ffb415cd8c69b3c0bb5f61844dde0b169f78 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 15:59:37 +0200 Subject: fix to tokens lenght, addend embs generator, add new features to edit the embedding before the generation using text --- modules/aesthetic_clip.py | 78 ++++++++++++++++++++++++ modules/processing.py | 148 +++++++++++++++++++++++++++++++--------------- modules/sd_hijack.py | 111 ++++++++++++++++++++++------------ modules/shared.py | 4 ++ modules/txt2img.py | 10 +++- modules/ui.py | 47 ++++++++++++--- 6 files changed, 302 insertions(+), 96 deletions(-) create mode 100644 modules/aesthetic_clip.py (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py new file mode 100644 index 00000000..f15cfd47 --- /dev/null +++ b/modules/aesthetic_clip.py @@ -0,0 +1,78 @@ +import itertools +import os +from pathlib import Path +import html +import gc + +import gradio as gr +import torch +from PIL import Image +from modules import shared +from modules.shared import device, aesthetic_embeddings +from transformers import CLIPModel, CLIPProcessor + +from tqdm.auto import tqdm + + +def get_all_images_in_folder(folder): + return [os.path.join(folder, f) for f in os.listdir(folder) if + os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)] + + +def check_is_valid_image_file(filename): + return filename.lower().endswith(('.png', '.jpg', '.jpeg')) + + +def batched(dataset, total, n=1): + for ndx in range(0, total, n): + yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] + + +def iter_to_batched(iterable, n=1): + it = iter(iterable) + while True: + chunk = tuple(itertools.islice(it, n)) + if not chunk: + return + yield chunk + + +def generate_imgs_embd(name, folder, batch_size): + # clipModel = CLIPModel.from_pretrained( + # shared.sd_model.cond_stage_model.clipModel.name_or_path + # ) + model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device) + processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path) + + with torch.no_grad(): + embs = [] + for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size), + desc=f"Generating embeddings for {name}"): + if shared.state.interrupted: + break + inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device) + outputs = model.get_image_features(**inputs).cpu() + embs.append(torch.clone(outputs)) + inputs.to("cpu") + del inputs, outputs + + embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True) + + # The generated embedding will be located here + path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") + torch.save(embs, path) + + model = model.cpu() + del model + del processor + del embs + gc.collect() + torch.cuda.empty_cache() + res = f""" + Done generating embedding for {name}! + Hypernetwork saved to {html.escape(path)} + """ + shared.update_aesthetic_embeddings() + return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", + value=sorted(aesthetic_embeddings.keys())[0] if len( + aesthetic_embeddings) > 0 else None), res, "" diff --git a/modules/processing.py b/modules/processing.py index 9a033759..ab68d63a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -20,7 +20,6 @@ import modules.images as images import modules.styles import logging - # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 opt_f = 8 @@ -52,8 +51,13 @@ def get_correct_sampler(p): elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): return sd_samplers.samplers_for_img2img + class StableDiffusionProcessing: - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, + subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, + sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, + restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, + extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids @@ -104,7 +108,8 @@ class StableDiffusionProcessing: class Processed: - def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): + def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, + all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): self.images = images_list self.prompt = p.prompt self.negative_prompt = p.negative_prompt @@ -141,7 +146,8 @@ class Processed: self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) - self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 + self.subseed = int( + self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.all_prompts = all_prompts or [self.prompt] self.all_seeds = all_seeds or [self.seed] @@ -181,39 +187,43 @@ class Processed: return json.dumps(obj) - def infotext(self, p: StableDiffusionProcessing, index): - return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size) + def infotext(self, p: StableDiffusionProcessing, index): + return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], + position_in_batch=index % self.batch_size, iteration=index // self.batch_size) # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 def slerp(val, low, high): - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) - dot = (low_norm*high_norm).sum(1) + low_norm = low / torch.norm(low, dim=1, keepdim=True) + high_norm = high / torch.norm(high, dim=1, keepdim=True) + dot = (low_norm * high_norm).sum(1) if dot.mean() > 0.9995: return low * val + high * (1 - val) omega = torch.acos(dot) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high return res -def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None): +def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, + p=None): xs = [] # if we have multiple seeds, this means we are working with batch size>1; this then # enables the generation of additional tensors with noise that the sampler will use during its processing. # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # produce the same images as with two batches [100], [101]. - if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): + if p is not None and p.sampler is not None and ( + len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] else: sampler_noises = None for i, seed in enumerate(seeds): - noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8) + noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else ( + shape[0], seed_resize_from_h // 8, seed_resize_from_w // 8) subnoise = None if subseeds is not None: @@ -241,7 +251,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see dx = max(-dx, 0) dy = max(-dy, 0) - x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w] + x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w] noise = x if sampler_noises is not None: @@ -293,14 +303,20 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", - "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), - "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')), + "Model hash": getattr(p, 'sd_model_hash', + None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), + "Model": ( + None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace( + ',', '').replace(':', '')), + "Hypernet": ( + None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace( + ':', '')), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), - "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), + "Seed resize from": ( + None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, @@ -309,7 +325,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params.update(p.extra_generation_params) - generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join( + [k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" @@ -317,7 +334,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, - aesthetic_imgs=None,aesthetic_slerp=False) -> Processed: + aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" aesthetic_lr = float(aesthetic_lr) @@ -385,7 +404,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh for n in range(p.n_iter): if state.skipped: state.skipped = False - + if state.interrupted: break @@ -396,16 +415,19 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh if (len(prompts) == 0): break - #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) - #c = p.sd_model.get_learned_conditioning(prompts) + # uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) + # c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): - shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0) + shared.sd_model.cond_stage_model.set_aesthetic_params() uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight, - aesthetic_steps, aesthetic_imgs,aesthetic_slerp) + aesthetic_steps, aesthetic_imgs, + aesthetic_slerp, aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: @@ -413,13 +435,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh comments[comment] = 1 if p.n_iter > 1: - shared.state.job = f"Batch {n+1} out of {p.n_iter}" + shared.state.job = f"Batch {n + 1} out of {p.n_iter}" with devices.autocast(): - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, + subseed_strength=p.subseed_strength) if state.interrupted or state.skipped: - # if we are interrupted, sample returns just noise # use the image collected previously in sampler loop samples_ddim = shared.state.current_latent @@ -445,7 +467,9 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh if p.restore_faces: if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: - images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") + images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], + opts.samples_format, info=infotext(n, i), p=p, + suffix="-before-face-restoration") devices.torch_gc() @@ -456,7 +480,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: - images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction") + images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, + info=infotext(n, i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) if p.overlay_images is not None and i < len(p.overlay_images): @@ -474,7 +499,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh image = image.convert('RGB') if opts.samples_save and not p.do_not_save_samples: - images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) + images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, + info=infotext(n, i), p=p) text = infotext(n, i) infotexts.append(text) @@ -482,7 +508,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh image.info["parameters"] = text output_images.append(image) - del x_samples_ddim + del x_samples_ddim devices.torch_gc() @@ -504,10 +530,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh index_of_first_image = 1 if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, + info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) devices.torch_gc() - return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) + return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), + subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, + index_of_first_image=index_of_first_image, infotexts=infotexts) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): @@ -543,25 +572,34 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, + subseeds=subseeds, subseed_strength=self.subseed_strength, + seed_resize_from_h=self.seed_resize_from_h, + seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) return samples - x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, + subseeds=subseeds, subseed_strength=self.subseed_strength, + seed_resize_from_h=self.seed_resize_from_h, + seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f - samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2] + samples = samples[:, :, truncate_y // 2:samples.shape[2] - truncate_y // 2, + truncate_x // 2:samples.shape[3] - truncate_x // 2] if self.scale_latent: - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), + mode="bilinear") else: decoded_samples = decode_first_stage(self.sd_model, samples) if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": - decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") + decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), + mode="bilinear") else: lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) @@ -585,13 +623,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) - noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, + subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, + seed_resize_from_w=self.seed_resize_from_w, p=self) # GC now before running the next img2img to prevent running out of memory x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, + steps=self.steps) return samples @@ -599,7 +640,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): sampler = None - def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs): + def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, + inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, + **kwargs): super().__init__(**kwargs) self.init_images = init_images @@ -607,7 +650,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.denoising_strength: float = denoising_strength self.init_latent = None self.image_mask = mask - #self.image_unblurred_mask = None + # self.image_unblurred_mask = None self.latent_mask = None self.mask_for_overlay = None self.mask_blur = mask_blur @@ -619,7 +662,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.nmask = None def init(self, all_prompts, all_seeds, all_subseeds): - self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, + self.sd_model) crop_region = None if self.image_mask is not None: @@ -628,7 +672,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.inpainting_mask_invert: self.image_mask = ImageOps.invert(self.image_mask) - #self.image_unblurred_mask = self.image_mask + # self.image_unblurred_mask = self.image_mask if self.mask_blur > 0: self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)) @@ -642,7 +686,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask = mask.crop(crop_region) self.image_mask = images.resize_image(2, mask, self.width, self.height) - self.paste_to = (x1, y1, x2-x1, y2-y1) + self.paste_to = (x1, y1, x2 - x1, y2 - y1) else: self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height) np_mask = np.array(self.image_mask) @@ -665,7 +709,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.image_mask is not None: image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + image_masked.paste(image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) self.overlay_images.append(image_masked.convert('RGBA')) @@ -714,12 +759,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): # this needs to be fixed to be done in sample() using actual seeds for batches if self.inpainting_fill == 2: - self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask + self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], + all_seeds[ + 0:self.init_latent.shape[ + 0]]) * self.nmask elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, + subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, + seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6d5196fe..192883b2 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -14,7 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model -from transformers import CLIPVisionModel, CLIPModel +from tqdm import trange +from transformers import CLIPVisionModel, CLIPModel, CLIPTokenizer import torch.optim as optim import copy @@ -22,21 +23,25 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward + def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and ( + 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): + elif not cmd_opts.disable_opt_split_attention and ( + cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): if not invokeAI_mps_available and shared.device.type == 'mps': - print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") + print( + "The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 else: @@ -112,14 +117,16 @@ class StableDiffusionModelHijack: _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) + def slerp(low, high, val): - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm*high_norm).sum(1)) + low_norm = low / torch.norm(low, dim=1, keepdim=True) + high_norm = high / torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm * high_norm).sum(1)) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high return res + class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() @@ -128,6 +135,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.wrapped.transformer.name_or_path ) del self.clipModel.vision_model + self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path) self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval() @@ -139,7 +147,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] - tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if + '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 for c in text: @@ -155,8 +164,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None, - aesthetic_slerp=True): + def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, + aesthetic_slerp=True, aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False): + self.aesthetic_imgs_text = aesthetic_imgs_text + self.aesthetic_slerp_angle = aesthetic_slerp_angle + self.aesthetic_text_negative = aesthetic_text_negative self.slerp = aesthetic_slerp self.aesthetic_lr = aesthetic_lr self.aesthetic_weight = aesthetic_weight @@ -180,7 +194,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: parsed = [[line, 1.0]] - tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"] + tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)[ + "input_ids"] fixes = [] remade_tokens = [] @@ -196,18 +211,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if token == self.comma_token: last_comma = len(remade_tokens) - elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: + elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), + 1) % 75 == 0 and last_comma != -1 and len( + remade_tokens) - last_comma <= opts.comma_padding_backtrack: last_comma += 1 reloc_tokens = remade_tokens[last_comma:] reloc_mults = multipliers[last_comma:] remade_tokens = remade_tokens[:last_comma] length = len(remade_tokens) - + rem = int(math.ceil(length / 75)) * 75 - length remade_tokens += [id_end] * rem + reloc_tokens multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults - + if embedding is None: remade_tokens.append(token) multipliers.append(weight) @@ -248,7 +265,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if line in cache: remade_tokens, fixes, multipliers = cache[line] else: - remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, + hijack_comments) token_count = max(current_token_count, token_count) cache[line] = (remade_tokens, fixes, multipliers) @@ -259,7 +277,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - def process_text_old(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id @@ -289,7 +306,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, + i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: @@ -312,11 +330,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ovf = remade_tokens[maxlen - 2:] overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + hijack_comments.append( + f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] cache[tuple_tokens] = (remade_tokens, fixes, multipliers) multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) @@ -326,23 +345,26 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): hijack_fixes.append(fixes) batch_multipliers.append(multipliers) return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - + def forward(self, text): use_old = opts.use_old_emphasis_implementation if use_old: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old( + text) else: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text( + text) self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: - self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - + self.hijack.comments.append( + "Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + if use_old: self.hijack.fixes = hijack_fixes return self.process_tokens(remade_batch_tokens, batch_multipliers) - + z = None i = 0 while max(map(len, remade_batch_tokens)) != 0: @@ -356,7 +378,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if fix[0] == i: fixes.append(fix[1]) self.hijack.fixes.append(fixes) - + tokens = [] multipliers = [] for j in range(len(remade_batch_tokens)): @@ -378,19 +400,30 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens).to(device) + + model = copy.deepcopy(self.clipModel).to(device) + model.requires_grad_(True) + if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: + text_embs_2 = model.get_text_features( + **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) + if self.aesthetic_text_negative: + text_embs_2 = self.image_embs - text_embs_2 + text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) + img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) + else: + img_embs = self.image_embs + with torch.enable_grad(): - model = copy.deepcopy(self.clipModel).to(device) - model.requires_grad_(True) # We optimize the model to maximize the similarity optimizer = optim.Adam( model.text_model.parameters(), lr=self.aesthetic_lr ) - for i in range(self.aesthetic_steps): + for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"): text_embs = model.get_text_features(input_ids=tokens) text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) - sim = text_embs @ self.image_embs.T + sim = text_embs @ img_embs.T loss = -sim optimizer.zero_grad() loss.mean().backward() @@ -405,6 +438,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): model.cpu() del model + zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1) if self.slerp: z = slerp(z, zn, self.aesthetic_weight) else: @@ -413,15 +447,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 - + return z - - + def process_tokens(self, remade_batch_tokens, batch_multipliers): if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] + remade_batch_tokens = [ + [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in + remade_batch_tokens] batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] - + tokens = torch.asarray(remade_batch_tokens).to(device) outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) @@ -461,8 +496,8 @@ class EmbeddingsWithFixes(torch.nn.Module): for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: emb = embedding.vec - emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) - tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) + emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) + tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) vecs.append(tensor) diff --git a/modules/shared.py b/modules/shared.py index cf13a10d..7cd608ca 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -95,6 +95,10 @@ loaded_hypernetwork = None aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} +def update_aesthetic_embeddings(): + global aesthetic_embeddings + aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in + os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} def reload_hypernetworks(): global hypernetworks diff --git a/modules/txt2img.py b/modules/txt2img.py index 78342024..eedcdfe0 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -13,7 +13,11 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, - aesthetic_slerp=False, *args): + aesthetic_slerp=False, + aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False, + *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -47,7 +51,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is None: - processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp) + processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative) shared.total_tqdm.clear() diff --git a/modules/ui.py b/modules/ui.py index d961d126..e98e2113 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -41,6 +41,7 @@ from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui +import modules.aesthetic_clip # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -449,7 +450,7 @@ def create_toprow(is_img2img): with gr.Row(): negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2) with gr.Column(scale=1, elem_id="roll_col"): - sh = gr.Button(elem_id="sh", visible=True) + sh = gr.Button(elem_id="sh", visible=True) with gr.Column(scale=1, elem_id="style_neg_col"): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) @@ -536,9 +537,13 @@ def create_ui(wrap_gradio_gpu_call): height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) with gr.Group(): - aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50) + aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001") + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) + aesthetic_steps = gr.Slider(minimum=0, maximum=256, step=1, label="Aesthetic steps", value=5) + with gr.Row(): + aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") + aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) + aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) @@ -617,7 +622,10 @@ def create_ui(wrap_gradio_gpu_call): aesthetic_weight, aesthetic_steps, aesthetic_imgs, - aesthetic_slerp + aesthetic_slerp, + aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative ] + custom_inputs, outputs=[ txt2img_gallery, @@ -721,7 +729,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) - inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) + inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=1024, step=4, value=32) with gr.TabItem('Batch img2img', id='batch'): hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' @@ -1071,6 +1079,17 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create embedding", variant='primary') + with gr.Tab(label="Create images embedding"): + new_embedding_name_ae = gr.Textbox(label="Name") + process_src_ae = gr.Textbox(label='Source directory') + batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256) + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding_ae = gr.Button(value="Create images embedding", variant='primary') + with gr.Tab(label="Create hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) @@ -1139,7 +1158,7 @@ def create_ui(wrap_gradio_gpu_call): fn=modules.textual_inversion.ui.create_embedding, inputs=[ new_embedding_name, - initialization_text, + process_src, nvpt, ], outputs=[ @@ -1149,6 +1168,20 @@ def create_ui(wrap_gradio_gpu_call): ] ) + create_embedding_ae.click( + fn=modules.aesthetic_clip.generate_imgs_embd, + inputs=[ + new_embedding_name_ae, + process_src_ae, + batch_ae + ], + outputs=[ + aesthetic_imgs, + ti_output, + ti_outcome, + ] + ) + create_hypernetwork.click( fn=modules.hypernetworks.ui.create_hypernetwork, inputs=[ -- cgit v1.2.3 From 4387e4fe6479c08f7bc7e42924c3a1093e3a1872 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 18:39:29 +0200 Subject: Update modules/ui.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Víctor Gallego --- modules/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index d0696101..5bb961b2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -599,7 +599,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001") aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) - aesthetic_steps = gr.Slider(minimum=0, maximum=256, step=1, label="Aesthetic steps", value=5) + aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) + with gr.Row(): aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) -- cgit v1.2.3 From 9b7705e0573bddde26df4575c71f994d73a4d519 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 18:40:34 +0200 Subject: Update modules/aesthetic_clip.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Víctor Gallego --- modules/aesthetic_clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index f15cfd47..bcf2b073 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -70,7 +70,7 @@ def generate_imgs_embd(name, folder, batch_size): torch.cuda.empty_cache() res = f""" Done generating embedding for {name}! - Hypernetwork saved to {html.escape(path)} + Aesthetic embedding saved to {html.escape(path)} """ shared.update_aesthetic_embeddings() return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", -- cgit v1.2.3 From 0d4f5db235357aeb4c7a8738179ba33aaf5a6b75 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 18:40:58 +0200 Subject: Update modules/ui.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Víctor Gallego --- modules/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 5bb961b2..25eba548 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -597,7 +597,8 @@ def create_ui(wrap_gradio_gpu_call): height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) with gr.Group(): - aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001") + aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) -- cgit v1.2.3 From ad9bc604a8fadcfebe72be37f66cec51e7e87fb5 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 18:41:18 +0200 Subject: Update modules/ui.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Víctor Gallego --- modules/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 25eba548..3b28b69c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -607,7 +607,8 @@ def create_ui(wrap_gradio_gpu_call): aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) - aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) + aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Aesthetic imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) + aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) with gr.Row(): -- cgit v1.2.3 From 3f5c3b981e46c16bb10948d012575b25170efb3b Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 18:41:46 +0200 Subject: Update modules/ui.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Víctor Gallego --- modules/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 3b28b69c..1f6fcdc9 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1190,7 +1190,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create embedding", variant='primary') - with gr.Tab(label="Create images embedding"): + with gr.Tab(label="Create aesthetic images embedding"): + new_embedding_name_ae = gr.Textbox(label="Name") process_src_ae = gr.Textbox(label='Source directory') batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256) -- cgit v1.2.3 From 3d21684ee30ca5734126b8d08c05b3a0f513fe75 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 00:01:00 +0200 Subject: Add support to other img format, fixed dropbox update --- modules/aesthetic_clip.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index bcf2b073..68264284 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -8,7 +8,7 @@ import gradio as gr import torch from PIL import Image from modules import shared -from modules.shared import device, aesthetic_embeddings +from modules.shared import device from transformers import CLIPModel, CLIPProcessor from tqdm.auto import tqdm @@ -20,7 +20,7 @@ def get_all_images_in_folder(folder): def check_is_valid_image_file(filename): - return filename.lower().endswith(('.png', '.jpg', '.jpeg')) + return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp")) def batched(dataset, total, n=1): @@ -73,6 +73,6 @@ def generate_imgs_embd(name, folder, batch_size): Aesthetic embedding saved to {html.escape(path)} """ shared.update_aesthetic_embeddings() - return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", - value=sorted(aesthetic_embeddings.keys())[0] if len( - aesthetic_embeddings) > 0 else None), res, "" + return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", + value=sorted(shared.aesthetic_embeddings.keys())[0] if len( + shared.aesthetic_embeddings) > 0 else None), res, "" -- cgit v1.2.3 From 9325c85f780c569d1823e422eaf51b2e497e0d3e Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 00:23:47 +0200 Subject: fixed dropbox update --- modules/sd_hijack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 192883b2..491312b4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -9,7 +9,7 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared -from modules.shared import opts, device, cmd_opts, aesthetic_embeddings +from modules.shared import opts, device, cmd_opts from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention @@ -182,7 +182,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): image_embs_name = None if image_embs_name is not None and self.image_embs_name != image_embs_name: self.image_embs_name = image_embs_name - self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device) + self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) self.image_embs.requires_grad_(False) -- cgit v1.2.3 From 763b893f319cee280b86e63025eb55e7c16b02e7 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sun, 16 Oct 2022 10:03:09 +0800 Subject: images history sorting files by date --- javascript/images_history.js | 12 +- modules/images_history.py | 261 ++++++++++++++++++++++++++++++++----------- 2 files changed, 202 insertions(+), 71 deletions(-) (limited to 'modules') diff --git a/javascript/images_history.js b/javascript/images_history.js index 7f0d8f42..ac5834c7 100644 --- a/javascript/images_history.js +++ b/javascript/images_history.js @@ -88,10 +88,10 @@ function images_history_set_image_info(button){ } -function images_history_get_current_img(tabname, image_path, files){ +function images_history_get_current_img(tabname, img_index, files){ return [ - gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"), - image_path, + tabname, + gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"), files ]; } @@ -129,7 +129,7 @@ function images_history_delete(del_num, tabname, img_file_name, page_index, file setTimeout(function(btn){btn.click()}, 30, btn); } images_history_disabled_del(); - return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index]; + return [del_num, tabname, img_file_name, page_index, filenames, image_index]; } function images_history_turnpage(img_path, page_index, image_index, tabname, date_from, date_to){ @@ -170,8 +170,8 @@ function images_history_init(){ } tabs_box.classList.add(images_history_tab_list[0]); - // same as above, at page load - //load_txt2img_button.click(); + // same as above, at page load-- load very fast now + load_txt2img_button.click(); } else { setTimeout(images_history_init, 500); } diff --git a/modules/images_history.py b/modules/images_history.py index f5ef44fe..533cf51b 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -1,33 +1,74 @@ import os import shutil +import time +import hashlib +import gradio +show_max_dates_num = 3 +system_bak_path = "webui_log_and_bak" +def is_valid_date(date): + try: + time.strptime(date, "%Y%m%d") + return True + except: + return False +def reduplicative_file_move(src, dst): + def same_name_file(basename, path): + name, ext = os.path.splitext(basename) + f_list = os.listdir(path) + max_num = 0 + for f in f_list: + if len(f) <= len(basename): + continue + f_ext = f[-len(ext):] if len(ext) > 0 else "" + if f[:len(name)] == name and f_ext == ext: + if f[len(name)] == "(" and f[-len(ext)-1] == ")": + number = f[len(name)+1:-len(ext)-1] + if number.isdigit(): + if int(number) > max_num: + max_num = int(number) + return f"{name}({max_num + 1}){ext}" + name = os.path.basename(src) + save_name = os.path.join(dst, name) + if not os.path.exists(save_name): + shutil.move(src, dst) + else: + name = same_name_file(name, dst) + shutil.move(src, os.path.join(dst, name)) -def traverse_all_files(output_dir, image_list, curr_dir=None): - curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir) +def traverse_all_files(curr_path, image_list, all_type=False): try: f_list = os.listdir(curr_path) except: - if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt": - image_list.append(curr_dir) + if all_type or curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt": + image_list.append(curr_path) return image_list for file in f_list: - file = file if curr_dir is None else os.path.join(curr_dir, file) - file_path = os.path.join(curr_path, file) - if file[-4:] == ".txt": + file = os.path.join(curr_path, file) + if (not all_type) and file[-4:] == ".txt": pass - elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0: + elif os.path.isfile(file) and file[-10:].rfind(".") > 0: image_list.append(file) else: - image_list = traverse_all_files(output_dir, image_list, file) + image_list = traverse_all_files(file, image_list) return image_list - -def get_recent_images(dir_name, page_index, step, image_index, tabname): - page_index = int(page_index) - f_list = os.listdir(dir_name) +def get_recent_images(dir_name, page_index, step, image_index, tabname, date_from, date_to): + #print(f"turn_page {page_index}",date_from) + if date_from is None or date_from == "": + return None, 1, None, "" image_list = [] - image_list = traverse_all_files(dir_name, image_list) - image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) + date_list = auto_sorting(dir_name) + page_index = int(page_index) + today = time.strftime("%Y%m%d",time.localtime(time.time())) + for date in date_list: + if date >= date_from and date <= date_to: + path = os.path.join(dir_name, date) + if date == today and not os.path.exists(path): + continue + image_list = traverse_all_files(path, image_list) + + image_list = sorted(image_list, key=lambda file: -os.path.getctime(file)) num = 48 if tabname != "extras" else 12 max_page_index = len(image_list) // num + 1 page_index = max_page_index if page_index == -1 else page_index + step @@ -38,40 +79,101 @@ def get_recent_images(dir_name, page_index, step, image_index, tabname): image_index = int(image_index) if image_index < 0 or image_index > len(image_list) - 1: current_file = None - hidden = None else: - current_file = image_list[int(image_index)] - hidden = os.path.join(dir_name, current_file) - return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, "" + current_file = image_list[image_index] + return image_list, page_index, image_list, "" +def auto_sorting(dir_name): + #print(f"auto sorting") + bak_path = os.path.join(dir_name, system_bak_path) + if not os.path.exists(bak_path): + os.mkdir(bak_path) + log_file = None + files_list = [] + f_list = os.listdir(dir_name) + for file in f_list: + if file == system_bak_path: + continue + file_path = os.path.join(dir_name, file) + if not is_valid_date(file): + if file[-10:].rfind(".") > 0: + files_list.append(file_path) + else: + files_list = traverse_all_files(file_path, files_list, all_type=True) + + for file in files_list: + date_str = time.strftime("%Y%m%d",time.localtime(os.path.getctime(file))) + file_path = os.path.dirname(file) + hash_path = hashlib.md5(file_path.encode()).hexdigest() + path = os.path.join(dir_name, date_str, hash_path) + if not os.path.exists(path): + os.makedirs(path) + if log_file is None: + log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a") + log_file.write(f"{hash_path},{file_path}\n") + reduplicative_file_move(file, path) + + date_list = [] + f_list = os.listdir(dir_name) + for f in f_list: + if is_valid_date(f): + date_list.append(f) + elif f == system_bak_path: + continue + else: + reduplicative_file_move(os.path.join(dir_name, f), bak_path) + + today = time.strftime("%Y%m%d",time.localtime(time.time())) + if today not in date_list: + date_list.append(today) + return sorted(date_list) -def first_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, 1, 0, image_index, tabname) -def end_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, -1, 0, image_index, tabname) +def archive_images(dir_name): + date_list = auto_sorting(dir_name) + date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0] + return ( + gradio.update(visible=False), + gradio.update(visible=True), + gradio.Dropdown.update(choices=date_list, value=date_list[-1]), + gradio.Dropdown.update(choices=date_list, value=date_from) + ) +def date_to_change(dir_name, page_index, image_index, tabname, date_from, date_to): + #print("date_to", date_to) + date_list = auto_sorting(dir_name) + date_from_list = [date for date in date_list if date <= date_to] + date_from = date_from_list[0] if len(date_from_list) < show_max_dates_num else date_from_list[-show_max_dates_num] + image_list, page_index, image_list, _ =get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to) + return image_list, page_index, image_list, _, gradio.Dropdown.update(choices=date_from_list, value=date_from) -def prev_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, page_index, -1, image_index, tabname) +def first_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): + return get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to) -def next_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, page_index, 1, image_index, tabname) +def end_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): + return get_recent_images(dir_name, -1, 0, image_index, tabname, date_from, date_to) -def page_index_change(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, page_index, 0, image_index, tabname) +def prev_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): + return get_recent_images(dir_name, page_index, -1, image_index, tabname, date_from, date_to) -def show_image_info(num, image_path, filenames): - # print(f"select image {num}") - file = filenames[int(num)] - return file, num, os.path.join(image_path, file) +def next_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): + return get_recent_images(dir_name, page_index, 1, image_index, tabname, date_from, date_to) + + +def page_index_change(dir_name, page_index, image_index, tabname, date_from, date_to): + return get_recent_images(dir_name, page_index, 0, image_index, tabname, date_from, date_to) -def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index): +def show_image_info(tabname_box, num, filenames): + # #print(f"select image {num}") + file = filenames[int(num)] + return file, num, file + +def delete_image(delete_num, tabname, name, page_index, filenames, image_index): if name == "": return filenames, delete_num else: @@ -81,21 +183,19 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima new_file_list = [] for name in filenames: if i >= index and i < index + delete_num: - path = os.path.join(dir_name, name) - if os.path.exists(path): - print(f"Delete file {path}") - os.remove(path) - txt_file = os.path.splitext(path)[0] + ".txt" + if os.path.exists(name): + #print(f"Delete file {name}") + os.remove(name) + txt_file = os.path.splitext(name)[0] + ".txt" if os.path.exists(txt_file): os.remove(txt_file) else: - print(f"Not exists file {path}") + #print(f"Not exists file {name}") else: new_file_list.append(name) i += 1 return new_file_list, 1 - def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): if tabname == "txt2img": dir_name = opts.outdir_txt2img_samples @@ -107,16 +207,32 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): dir_name = d[0] for p in d[1:]: dir_name = os.path.join(dir_name, p) - with gr.Row(): - renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page") - first_page = gr.Button('First Page') - prev_page = gr.Button('Prev Page') - page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next Page') - end_page = gr.Button('End Page') - with gr.Row(elem_id=tabname + "_images_history"): + + f_list = os.listdir(dir_name) + sorted_flag = os.path.exists(os.path.join(dir_name, system_bak_path)) or len(f_list) == 0 + date_list, date_from, date_to = None, None, None + if sorted_flag: + #print(sorted_flag) + date_list = auto_sorting(dir_name) + date_to = date_list[-1] + date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0] + + with gr.Column(visible=sorted_flag) as page_panel: with gr.Row(): + renew_page = gr.Button('Refresh', elem_id=tabname + "_images_history_renew_page", interactive=sorted_flag) + first_page = gr.Button('First Page') + prev_page = gr.Button('Prev Page') + page_index = gr.Number(value=1, label="Page Index") + next_page = gr.Button('Next Page') + end_page = gr.Button('End Page') + + with gr.Row(elem_id=tabname + "_images_history"): with gr.Column(scale=2): + with gr.Row(): + newest = gr.Button('Newest') + date_to = gr.Dropdown(choices=date_list, value=date_to, label="Date to") + date_from = gr.Dropdown(choices=date_list, value=date_from, label="Date from") + history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) with gr.Row(): delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") @@ -128,22 +244,31 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Row(): with gr.Column(): img_file_info = gr.Textbox(label="Generate Info", interactive=False) - img_file_name = gr.Textbox(label="File Name", interactive=False) - with gr.Row(): + img_file_name = gr.Textbox(value="", label="File Name", interactive=False) # hiden items + with gr.Row(visible=False): + img_path = gr.Textbox(dir_name) + tabname_box = gr.Textbox(tabname) + image_index = gr.Textbox(value=-1) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") + filenames = gr.State() + hidden = gr.Image(type="pil") + info1 = gr.Textbox() + info2 = gr.Textbox() + with gr.Column(visible=not sorted_flag) as init_warning: + with gr.Row(): + gr.Textbox("The system needs to archive the files according to the date. This requires changing the directory structure of the files", + label="Waring", + css="") + with gr.Row(): + sorted_button = gr.Button('Confirme') - img_path = gr.Textbox(dir_name.rstrip("/"), visible=False) - tabname_box = gr.Textbox(tabname, visible=False) - image_index = gr.Textbox(value=-1, visible=False) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False) - filenames = gr.State() - hidden = gr.Image(type="pil", visible=False) - info1 = gr.Textbox(visible=False) - info2 = gr.Textbox(visible=False) - + + + # turn pages - gallery_inputs = [img_path, page_index, image_index, tabname_box] - gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name] + gallery_inputs = [img_path, page_index, image_index, tabname_box, date_from, date_to] + gallery_outputs = [history_gallery, page_index, filenames, img_file_name] first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) @@ -154,15 +279,21 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index]) # other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden]) + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, filenames], outputs=[img_file_name, image_index, hidden]) img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) - delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num]) + delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num]) hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) - + date_to.change(date_to_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs + [date_from]) # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') + sorted_button.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from]) + newest.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from]) + + + + def create_history_tabs(gr, opts, run_pnginfo, switch_dict): with gr.Blocks(analytics_enabled=False) as images_history: -- cgit v1.2.3 From 523140d7805c644700009b8a2483ff4eb4a22304 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 10:23:30 +0200 Subject: ui fix --- modules/aesthetic_clip.py | 3 +-- modules/sd_hijack.py | 3 +-- modules/shared.py | 2 ++ modules/ui.py | 24 ++++++++++++++---------- 4 files changed, 18 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index 68264284..ccb35c73 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -74,5 +74,4 @@ def generate_imgs_embd(name, folder, batch_size): """ shared.update_aesthetic_embeddings() return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", - value=sorted(shared.aesthetic_embeddings.keys())[0] if len( - shared.aesthetic_embeddings) > 0 else None), res, "" + value="None"), res, "" diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 01fcb78f..2de2eed5 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -392,8 +392,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - if len(text[ - 0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: + if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: if not opts.use_old_emphasis_implementation: remade_batch_tokens = [ [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in diff --git a/modules/shared.py b/modules/shared.py index 3c5ffef1..e2c98b2d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -96,11 +96,13 @@ loaded_hypernetwork = None aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} +aesthetic_embeddings = aesthetic_embeddings | {"None": None} def update_aesthetic_embeddings(): global aesthetic_embeddings aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} + aesthetic_embeddings = aesthetic_embeddings | {"None": None} def reload_hypernetworks(): global hypernetworks diff --git a/modules/ui.py b/modules/ui.py index 13ba3142..4069f0d2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -594,19 +594,23 @@ def create_ui(wrap_gradio_gpu_call): height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) with gr.Group(): - aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") - - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) + with gr.Accordion("Open for Clip Aesthetic!",open=False): + with gr.Row(): + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) + aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) - with gr.Row(): - aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") - aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) - aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) + with gr.Row(): + aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") + aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) + aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), + label="Aesthetic imgs embedding", + value="None") - aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Aesthetic imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) + with gr.Row(): + aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") + aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) + aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) - aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) -- cgit v1.2.3 From e4f8b5f00dd33b7547cc6b76fbed26bb83b37a64 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 10:28:21 +0200 Subject: ui fix --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2de2eed5..5d0590af 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -178,7 +178,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.load_image_embs(image_embs_name) def load_image_embs(self, image_embs_name): - if image_embs_name is None or len(image_embs_name) == 0: + if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": image_embs_name = None if image_embs_name is not None and self.image_embs_name != image_embs_name: self.image_embs_name = image_embs_name -- cgit v1.2.3 From f62905fdf928b54aa76765e5cbde8d538d494e49 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sun, 16 Oct 2022 21:22:38 +0800 Subject: images history speed up --- javascript/images_history.js | 39 ++++--- modules/images_history.py | 250 ++++++++++++++++++++++--------------------- 2 files changed, 147 insertions(+), 142 deletions(-) (limited to 'modules') diff --git a/javascript/images_history.js b/javascript/images_history.js index ac5834c7..fb1356d9 100644 --- a/javascript/images_history.js +++ b/javascript/images_history.js @@ -20,7 +20,7 @@ var images_history_click_image = function(){ var images_history_click_tab = function(){ var tabs_box = gradioApp().getElementById("images_history_tab"); if (!tabs_box.classList.contains(this.getAttribute("tabname"))) { - gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click(); + gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_start").click(); tabs_box.classList.add(this.getAttribute("tabname")) } } @@ -96,7 +96,7 @@ function images_history_get_current_img(tabname, img_index, files){ ]; } -function images_history_delete(del_num, tabname, img_file_name, page_index, filenames, image_index){ +function images_history_delete(del_num, tabname, image_index){ image_index = parseInt(image_index); var tab = gradioApp().getElementById(tabname + '_images_history'); var set_btn = tab.querySelector(".images_history_set_index"); @@ -107,6 +107,7 @@ function images_history_delete(del_num, tabname, img_file_name, page_index, file } }); var img_num = buttons.length / 2; + del_num = Math.min(img_num - image_index, del_num) if (img_num <= del_num){ setTimeout(function(tabname){ gradioApp().getElementById(tabname + '_images_history_renew_page').click(); @@ -114,30 +115,29 @@ function images_history_delete(del_num, tabname, img_file_name, page_index, file } else { var next_img for (var i = 0; i < del_num; i++){ - if (image_index + i < image_index + img_num){ - buttons[image_index + i].style.display = 'none'; - buttons[image_index + img_num + 1].style.display = 'none'; - next_img = image_index + i + 1 - } + buttons[image_index + i].style.display = 'none'; + buttons[image_index + i + img_num].style.display = 'none'; + next_img = image_index + i + 1 } var bnt; if (next_img >= img_num){ - btn = buttons[image_index - del_num]; + btn = buttons[image_index - 1]; } else { btn = buttons[next_img]; } setTimeout(function(btn){btn.click()}, 30, btn); } images_history_disabled_del(); - return [del_num, tabname, img_file_name, page_index, filenames, image_index]; + } -function images_history_turnpage(img_path, page_index, image_index, tabname, date_from, date_to){ +function images_history_turnpage(tabname){ + console.log("del_button") + gradioApp().getElementById(tabname + '_images_history_del_button').setAttribute('disabled','disabled'); var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item"); buttons.forEach(function(elem) { elem.style.display = 'block'; - }) - return [img_path, page_index, image_index, tabname, date_from, date_to]; + }) } function images_history_enable_del_buttons(){ @@ -147,7 +147,7 @@ function images_history_enable_del_buttons(){ } function images_history_init(){ - var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page') + var load_txt2img_button = gradioApp().getElementById('saved_images_history_start') if (load_txt2img_button){ for (var i in images_history_tab_list ){ tab = images_history_tab_list[i]; @@ -166,7 +166,8 @@ function images_history_init(){ // this refreshes history upon tab switch // until the history is known to work well, which is not the case now, we do not do this at startup - //tab_btns[i].addEventListener('click', images_history_click_tab); + // -- load page very fast now, so better user experience by automatically activating pages + tab_btns[i].addEventListener('click', images_history_click_tab); } tabs_box.classList.add(images_history_tab_list[0]); @@ -177,7 +178,7 @@ function images_history_init(){ } } -var images_history_tab_list = ["txt2img", "img2img", "extras"]; +var images_history_tab_list = ["saved", "txt2img", "img2img", "extras"]; setTimeout(images_history_init, 500); document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ @@ -188,18 +189,16 @@ document.addEventListener("DOMContentLoaded", function() { bnt.addEventListener('click', images_history_click_image, true); }); - // same as load_txt2img_button.click() above - /* var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg"); if (cls_btn){ cls_btn.addEventListener('click', function(){ - gradioApp().getElementById(tabname + '_images_history_renew_page').click(); + gradioApp().getElementById(tabname + '_images_history_del_button').setAttribute('disabled','disabled'); }, false); - }*/ + } } }); - mutationObserver.observe( gradioApp(), { childList:true, subtree:true }); + mutationObserver.observe(gradioApp(), { childList:true, subtree:true }); }); diff --git a/modules/images_history.py b/modules/images_history.py index 7fd75005..ae0b4e40 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -3,8 +3,10 @@ import shutil import time import hashlib import gradio -show_max_dates_num = 3 + system_bak_path = "webui_log_and_bak" +loads_files_num = 216 +num_of_imgs_per_page = 36 def is_valid_date(date): try: time.strptime(date, "%Y%m%d") @@ -53,38 +55,7 @@ def traverse_all_files(curr_path, image_list, all_type=False): image_list = traverse_all_files(file, image_list) return image_list -def get_recent_images(dir_name, page_index, step, image_index, tabname, date_from, date_to): - #print(f"turn_page {page_index}",date_from) - if date_from is None or date_from == "": - return None, 1, None, "" - image_list = [] - date_list = auto_sorting(dir_name) - page_index = int(page_index) - today = time.strftime("%Y%m%d",time.localtime(time.time())) - for date in date_list: - if date >= date_from and date <= date_to: - path = os.path.join(dir_name, date) - if date == today and not os.path.exists(path): - continue - image_list = traverse_all_files(path, image_list) - - image_list = sorted(image_list, key=lambda file: -os.path.getctime(file)) - num = 48 if tabname != "extras" else 12 - max_page_index = len(image_list) // num + 1 - page_index = max_page_index if page_index == -1 else page_index + step - page_index = 1 if page_index < 1 else page_index - page_index = max_page_index if page_index > max_page_index else page_index - idx_frm = (page_index - 1) * num - image_list = image_list[idx_frm:idx_frm + num] - image_index = int(image_index) - if image_index < 0 or image_index > len(image_list) - 1: - current_file = None - else: - current_file = image_list[image_index] - return image_list, page_index, image_list, "" - -def auto_sorting(dir_name): - #print(f"auto sorting") +def auto_sorting(dir_name): bak_path = os.path.join(dir_name, system_bak_path) if not os.path.exists(bak_path): os.mkdir(bak_path) @@ -126,102 +97,131 @@ def auto_sorting(dir_name): today = time.strftime("%Y%m%d",time.localtime(time.time())) if today not in date_list: date_list.append(today) - return sorted(date_list) + return sorted(date_list, reverse=True) -def archive_images(dir_name): +def archive_images(dir_name, date_to): date_list = auto_sorting(dir_name) - date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0] + today = time.strftime("%Y%m%d",time.localtime(time.time())) + date_to = today if date_to is None or date_to == "" else date_to + filenames = [] + for date in date_list: + if date <= date_to: + path = os.path.join(dir_name, date) + if date == today and not os.path.exists(path): + continue + filenames = traverse_all_files(path, filenames) + if len(filenames) > loads_files_num: + break + filenames = sorted(filenames, key=lambda file: -os.path.getctime(file)) + _, image_list, _, visible_num = get_recent_images(1, 0, filenames) return ( gradio.update(visible=False), gradio.update(visible=True), - gradio.Dropdown.update(choices=date_list, value=date_list[-1]), - gradio.Dropdown.update(choices=date_list, value=date_from) + gradio.Dropdown.update(choices=date_list, value=date_to), + date, + filenames, + 1, + image_list, + "", + visible_num ) +def system_init(dir_name): + ret = [x for x in archive_images(dir_name, None)] + ret += [gradio.update(visible=False)] + return ret + +def newest_click(dir_name, date_to): + if date_to == "start": + return True, False, "start", None, None, 1, None, "" + else: + return archive_images(dir_name, time.strftime("%Y%m%d",time.localtime(time.time()))) -def date_to_change(dir_name, page_index, image_index, tabname, date_from, date_to): - #print("date_to", date_to) - date_list = auto_sorting(dir_name) - date_from_list = [date for date in date_list if date <= date_to] - date_from = date_from_list[0] if len(date_from_list) < show_max_dates_num else date_from_list[-show_max_dates_num] - image_list, page_index, image_list, _ =get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to) - return image_list, page_index, image_list, _, gradio.Dropdown.update(choices=date_from_list, value=date_from) - -def first_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): - return get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to) - - -def end_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): - return get_recent_images(dir_name, -1, 0, image_index, tabname, date_from, date_to) - - -def prev_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): - return get_recent_images(dir_name, page_index, -1, image_index, tabname, date_from, date_to) - - -def next_page_click(dir_name, page_index, image_index, tabname, date_from, date_to): - return get_recent_images(dir_name, page_index, 1, image_index, tabname, date_from, date_to) - - -def page_index_change(dir_name, page_index, image_index, tabname, date_from, date_to): - return get_recent_images(dir_name, page_index, 0, image_index, tabname, date_from, date_to) - - -def show_image_info(tabname_box, num, filenames): - # #print(f"select image {num}") - file = filenames[int(num)] - return file, num, file - -def delete_image(delete_num, tabname, name, page_index, filenames, image_index): +def delete_image(delete_num, name, filenames, image_index, visible_num): if name == "": return filenames, delete_num else: delete_num = int(delete_num) + visible_num = int(visible_num) + image_index = int(image_index) index = list(filenames).index(name) i = 0 new_file_list = [] for name in filenames: if i >= index and i < index + delete_num: if os.path.exists(name): - #print(f"Delete file {name}") + if visible_num == image_index: + new_file_list.append(name) + continue + print(f"Delete file {name}") os.remove(name) + visible_num -= 1 txt_file = os.path.splitext(name)[0] + ".txt" if os.path.exists(txt_file): os.remove(txt_file) else: - #print(f"Not exists file {name}") + print(f"Not exists file {name}") else: new_file_list.append(name) i += 1 - return new_file_list, 1 + return new_file_list, 1, visible_num + +def get_recent_images(page_index, step, filenames): + page_index = int(page_index) + max_page_index = len(filenames) // num_of_imgs_per_page + 1 + page_index = max_page_index if page_index == -1 else page_index + step + page_index = 1 if page_index < 1 else page_index + page_index = max_page_index if page_index > max_page_index else page_index + idx_frm = (page_index - 1) * num_of_imgs_per_page + image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page] + length = len(filenames) + visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page + visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num + return page_index, image_list, "", visible_num + +def first_page_click(page_index, filenames): + return get_recent_images(1, 0, filenames) + +def end_page_click(page_index, filenames): + return get_recent_images(-1, 0, filenames) + +def prev_page_click(page_index, filenames): + return get_recent_images(page_index, -1, filenames) + +def next_page_click(page_index, filenames): + return get_recent_images(page_index, 1, filenames) + +def page_index_change(page_index, filenames): + return get_recent_images(page_index, 0, filenames) + +def show_image_info(tabname_box, num, page_index, filenames): + file = filenames[int(num) + int((page_index - 1) * num_of_imgs_per_page)] + return file, num, file def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): - if opts.outdir_samples != "": - dir_name = opts.outdir_samples - elif tabname == "txt2img": + if tabname == "txt2img": dir_name = opts.outdir_txt2img_samples elif tabname == "img2img": dir_name = opts.outdir_img2img_samples elif tabname == "extras": dir_name = opts.outdir_extras_samples + elif tabname == "saved": + dir_name = opts.outdir_save + if not os.path.exists(dir_name): + os.makedirs(dir_name) d = dir_name.split("/") - dir_name = "/" if dir_name.startswith("/") else d[0] + dir_name = d[0] for p in d[1:]: dir_name = os.path.join(dir_name, p) f_list = os.listdir(dir_name) sorted_flag = os.path.exists(os.path.join(dir_name, system_bak_path)) or len(f_list) == 0 date_list, date_from, date_to = None, None, None - if sorted_flag: - #print(sorted_flag) - date_list = auto_sorting(dir_name) - date_to = date_list[-1] - date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0] with gr.Column(visible=sorted_flag) as page_panel: with gr.Row(): - renew_page = gr.Button('Refresh', elem_id=tabname + "_images_history_renew_page", interactive=sorted_flag) + #renew_page = gr.Button('Refresh') first_page = gr.Button('First Page') prev_page = gr.Button('Prev Page') page_index = gr.Number(value=1, label="Page Index") @@ -231,9 +231,9 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Row(elem_id=tabname + "_images_history"): with gr.Column(scale=2): with gr.Row(): - newest = gr.Button('Newest') - date_to = gr.Dropdown(choices=date_list, value=date_to, label="Date to") - date_from = gr.Dropdown(choices=date_list, value=date_from, label="Date from") + newest = gr.Button('Refresh', elem_id=tabname + "_images_history_start") + date_from = gr.Textbox(label="Date from", interactive=False) + date_to = gr.Dropdown(value="start" if not sorted_flag else None, label="Date to") history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) with gr.Row(): @@ -247,66 +247,72 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Column(): img_file_info = gr.Textbox(label="Generate Info", interactive=False) img_file_name = gr.Textbox(value="", label="File Name", interactive=False) + # hiden items - with gr.Row(visible=False): + with gr.Row(visible=False): + visible_img_num = gr.Number() img_path = gr.Textbox(dir_name) tabname_box = gr.Textbox(tabname) image_index = gr.Textbox(value=-1) set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") filenames = gr.State() + all_images_list = gr.State() hidden = gr.Image(type="pil") info1 = gr.Textbox() info2 = gr.Textbox() + with gr.Column(visible=not sorted_flag) as init_warning: with gr.Row(): - gr.Textbox("The system needs to archive the files according to the date. This requires changing the directory structure of the files", - label="Waring", - css="") + warning = gr.Textbox( + label="Waring", + value=f"The system needs to archive the files according to the date. This requires changing the directory structure of the files.If you have doubts about this operation, you can first back up the files in the '{dir_name}' directory" + ) + warning.style(height=100, width=50) with gr.Row(): sorted_button = gr.Button('Confirme') - - + change_date_output = [init_warning, page_panel, date_to, date_from, filenames, page_index, history_gallery, img_file_name, visible_img_num] + sorted_button.click(system_init, inputs=[img_path], outputs=change_date_output + [sorted_button]) + newest.click(newest_click, inputs=[img_path, date_to], outputs=change_date_output) + date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output) + date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + newest.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + + delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num]) + delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None) + # turn pages - gallery_inputs = [img_path, page_index, image_index, tabname_box, date_from, date_to] - gallery_outputs = [history_gallery, page_index, filenames, img_file_name] + gallery_inputs = [page_index, filenames] + gallery_outputs = [page_index, history_gallery, img_file_name, visible_img_num] + + first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) - first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index]) + first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") # other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, filenames], outputs=[img_file_name, image_index, hidden]) - img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) - delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num]) + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden]) + img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) - date_to.change(date_to_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs + [date_from]) - # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) + switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') - sorted_button.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from]) - newest.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from]) - - - def create_history_tabs(gr, opts, run_pnginfo, switch_dict): with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: - with gr.Tab("txt2img history"): - with gr.Blocks(analytics_enabled=False) as images_history_txt2img: - show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict) - with gr.Tab("img2img history"): - with gr.Blocks(analytics_enabled=False) as images_history_img2img: - show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict) - with gr.Tab("extras history"): - with gr.Blocks(analytics_enabled=False) as images_history_img2img: - show_images_history(gr, opts, "extras", run_pnginfo, switch_dict) + for tab in ["saved", "txt2img", "img2img", "extras"]: + with gr.Tab(tab): + with gr.Blocks(analytics_enabled=False) as images_history_img2img: + show_images_history(gr, opts, tab, run_pnginfo, switch_dict) return images_history -- cgit v1.2.3 From a4de699e3c235d83b5a957d08779cb41cb0781bc Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sun, 16 Oct 2022 22:37:12 +0800 Subject: Images history speed up --- javascript/images_history.js | 1 + modules/images_history.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/javascript/images_history.js b/javascript/images_history.js index fb1356d9..9d9d04fb 100644 --- a/javascript/images_history.js +++ b/javascript/images_history.js @@ -108,6 +108,7 @@ function images_history_delete(del_num, tabname, image_index){ }); var img_num = buttons.length / 2; del_num = Math.min(img_num - image_index, del_num) + console.log(del_num, img_num) if (img_num <= del_num){ setTimeout(function(tabname){ gradioApp().getElementById(tabname + '_images_history_renew_page').click(); diff --git a/modules/images_history.py b/modules/images_history.py index ae0b4e40..94bd16a8 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -153,6 +153,7 @@ def delete_image(delete_num, name, filenames, image_index, visible_num): if os.path.exists(name): if visible_num == image_index: new_file_list.append(name) + i += 1 continue print(f"Delete file {name}") os.remove(name) @@ -221,7 +222,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Column(visible=sorted_flag) as page_panel: with gr.Row(): - #renew_page = gr.Button('Refresh') + renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") first_page = gr.Button('First Page') prev_page = gr.Button('Prev Page') page_index = gr.Number(value=1, label="Page Index") @@ -231,7 +232,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Row(elem_id=tabname + "_images_history"): with gr.Column(scale=2): with gr.Row(): - newest = gr.Button('Refresh', elem_id=tabname + "_images_history_start") + newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") date_from = gr.Textbox(label="Date from", interactive=False) date_to = gr.Dropdown(value="start" if not sorted_flag else None, label="Date to") @@ -291,12 +292,14 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs) page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) + renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") # other funcitons set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden]) -- cgit v1.2.3 From 9324cdaa3199d65c182858785dd1eca42b192b8e Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 17:53:56 +0200 Subject: ui fix, re organization of the code --- modules/aesthetic_clip.py | 154 +++++++++++++++++++++++++++++++++-- modules/img2img.py | 14 +++- modules/processing.py | 29 ++----- modules/sd_hijack.py | 102 ++--------------------- modules/sd_models.py | 5 +- modules/shared.py | 14 +++- modules/textual_inversion/dataset.py | 2 +- modules/txt2img.py | 18 ++-- modules/ui.py | 52 +++++++----- 9 files changed, 233 insertions(+), 157 deletions(-) (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index ccb35c73..34efa931 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -1,3 +1,4 @@ +import copy import itertools import os from pathlib import Path @@ -7,11 +8,12 @@ import gc import gradio as gr import torch from PIL import Image -from modules import shared -from modules.shared import device -from transformers import CLIPModel, CLIPProcessor +from torch import optim -from tqdm.auto import tqdm +from modules import shared +from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer +from tqdm.auto import tqdm, trange +from modules.shared import opts, device def get_all_images_in_folder(folder): @@ -37,12 +39,39 @@ def iter_to_batched(iterable, n=1): yield chunk +def create_ui(): + with gr.Group(): + with gr.Accordion("Open for Clip Aesthetic!", open=False): + with gr.Row(): + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", + value=0.9) + aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) + + with gr.Row(): + aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', + placeholder="Aesthetic learning rate", value="0.0001") + aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) + aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()), + label="Aesthetic imgs embedding", + value="None") + + with gr.Row(): + aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', + placeholder="This text is used to rotate the feature space of the imgs embs", + value="") + aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01, + value=0.1) + aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) + + return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative + + def generate_imgs_embd(name, folder, batch_size): # clipModel = CLIPModel.from_pretrained( # shared.sd_model.cond_stage_model.clipModel.name_or_path # ) - model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device) - processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path) + model = shared.clip_model.to(device) + processor = CLIPProcessor.from_pretrained(model.name_or_path) with torch.no_grad(): embs = [] @@ -63,7 +92,6 @@ def generate_imgs_embd(name, folder, batch_size): torch.save(embs, path) model = model.cpu() - del model del processor del embs gc.collect() @@ -74,4 +102,114 @@ def generate_imgs_embd(name, folder, batch_size): """ shared.update_aesthetic_embeddings() return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", - value="None"), res, "" + value="None"), \ + gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), + label="Imgs embedding", + value="None"), res, "" + + +def slerp(low, high, val): + low_norm = low / torch.norm(low, dim=1, keepdim=True) + high_norm = high / torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm * high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high + return res + + +class AestheticCLIP: + def __init__(self): + self.skip = False + self.aesthetic_steps = 0 + self.aesthetic_weight = 0 + self.aesthetic_lr = 0 + self.slerp = False + self.aesthetic_text_negative = "" + self.aesthetic_slerp_angle = 0 + self.aesthetic_imgs_text = "" + + self.image_embs_name = None + self.image_embs = None + self.load_image_embs(None) + + def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, + aesthetic_slerp=True, aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False): + self.aesthetic_imgs_text = aesthetic_imgs_text + self.aesthetic_slerp_angle = aesthetic_slerp_angle + self.aesthetic_text_negative = aesthetic_text_negative + self.slerp = aesthetic_slerp + self.aesthetic_lr = aesthetic_lr + self.aesthetic_weight = aesthetic_weight + self.aesthetic_steps = aesthetic_steps + self.load_image_embs(image_embs_name) + + def set_skip(self, skip): + self.skip = skip + + def load_image_embs(self, image_embs_name): + if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": + image_embs_name = None + self.image_embs_name = None + if image_embs_name is not None and self.image_embs_name != image_embs_name: + self.image_embs_name = image_embs_name + self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) + self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) + self.image_embs.requires_grad_(False) + + def __call__(self, z, remade_batch_tokens): + if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None: + tokenizer = shared.sd_model.cond_stage_model.tokenizer + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [ + [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in + remade_batch_tokens] + + tokens = torch.asarray(remade_batch_tokens).to(device) + + model = copy.deepcopy(shared.clip_model).to(device) + model.requires_grad_(True) + if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: + text_embs_2 = model.get_text_features( + **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) + if self.aesthetic_text_negative: + text_embs_2 = self.image_embs - text_embs_2 + text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) + img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) + else: + img_embs = self.image_embs + + with torch.enable_grad(): + + # We optimize the model to maximize the similarity + optimizer = optim.Adam( + model.text_model.parameters(), lr=self.aesthetic_lr + ) + + for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"): + text_embs = model.get_text_features(input_ids=tokens) + text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) + sim = text_embs @ img_embs.T + loss = -sim + optimizer.zero_grad() + loss.mean().backward() + optimizer.step() + + zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + if opts.CLIP_stop_at_last_layers > 1: + zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] + zn = model.text_model.final_layer_norm(zn) + else: + zn = zn.last_hidden_state + model.cpu() + del model + gc.collect() + torch.cuda.empty_cache() + zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1) + if self.slerp: + z = slerp(z, zn, self.aesthetic_weight) + else: + z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight + + return z diff --git a/modules/img2img.py b/modules/img2img.py index 24126774..4ed80c4b 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -56,7 +56,14 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, + aesthetic_lr=0, + aesthetic_weight=0, aesthetic_steps=0, + aesthetic_imgs=None, + aesthetic_slerp=False, + aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False, *args): is_inpaint = mode == 1 is_batch = mode == 2 @@ -109,6 +116,11 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpainting_mask_invert=inpainting_mask_invert, ) + shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), + aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative) + if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index 1db26c3e..685f9fcd 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -146,7 +146,8 @@ class Processed: self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1 - self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 + self.subseed = int( + self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.all_prompts = all_prompts or [self.prompt] self.all_seeds = all_seeds or [self.seed] @@ -332,16 +333,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() -def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, - aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", - aesthetic_slerp_angle=0.15, - aesthetic_text_negative=False) -> Processed: +def process_images(p: StableDiffusionProcessing) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" - aesthetic_lr = float(aesthetic_lr) - aesthetic_weight = float(aesthetic_weight) - aesthetic_steps = int(aesthetic_steps) - if type(p.prompt) == list: assert (len(p.prompt) > 0) else: @@ -417,16 +411,10 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh # uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) # c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): - if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): - shared.sd_model.cond_stage_model.set_aesthetic_params() + shared.aesthetic_clip.set_skip(True) uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) - if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): - shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight, - aesthetic_steps, aesthetic_imgs, - aesthetic_slerp, aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative) + shared.aesthetic_clip.set_skip(False) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: @@ -582,7 +570,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) @@ -600,10 +587,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) - samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] + samples = samples[:, :, self.truncate_y // 2:samples.shape[2] - self.truncate_y // 2, + self.truncate_x // 2:samples.shape[3] - self.truncate_x // 2] if opts.use_scale_latent_for_hires_fix: - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), + mode="bilinear") else: decoded_samples = decode_first_stage(self.sd_model, samples) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5d0590af..227e7670 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -29,8 +29,8 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and ( + 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward @@ -118,33 +118,14 @@ class StableDiffusionModelHijack: return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) -def slerp(low, high, val): - low_norm = low / torch.norm(low, dim=1, keepdim=True) - high_norm = high / torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm * high_norm).sum(1)) - so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high - return res - - class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped - self.clipModel = CLIPModel.from_pretrained( - self.wrapped.transformer.name_or_path - ) - del self.clipModel.vision_model - self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path) - self.hijack: StableDiffusionModelHijack = hijack - self.tokenizer = wrapped.tokenizer - # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval() - self.image_embs_name = None - self.image_embs = None - self.load_image_embs(None) self.token_mults = {} - + self.hijack: StableDiffusionModelHijack = hijack + self.tokenizer = wrapped.tokenizer self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if @@ -164,28 +145,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, - aesthetic_slerp=True, aesthetic_imgs_text="", - aesthetic_slerp_angle=0.15, - aesthetic_text_negative=False): - self.aesthetic_imgs_text = aesthetic_imgs_text - self.aesthetic_slerp_angle = aesthetic_slerp_angle - self.aesthetic_text_negative = aesthetic_text_negative - self.slerp = aesthetic_slerp - self.aesthetic_lr = aesthetic_lr - self.aesthetic_weight = aesthetic_weight - self.aesthetic_steps = aesthetic_steps - self.load_image_embs(image_embs_name) - - def load_image_embs(self, image_embs_name): - if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": - image_embs_name = None - if image_embs_name is not None and self.image_embs_name != image_embs_name: - self.image_embs_name = image_embs_name - self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) - self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) - self.image_embs.requires_grad_(False) - def tokenize_line(self, line, used_custom_terms, hijack_comments): id_end = self.wrapped.tokenizer.eos_token_id @@ -391,58 +350,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - - if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [ - [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in - remade_batch_tokens] - - tokens = torch.asarray(remade_batch_tokens).to(device) - - model = copy.deepcopy(self.clipModel).to(device) - model.requires_grad_(True) - if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: - text_embs_2 = model.get_text_features( - **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) - if self.aesthetic_text_negative: - text_embs_2 = self.image_embs - text_embs_2 - text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) - img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) - else: - img_embs = self.image_embs - - with torch.enable_grad(): - - # We optimize the model to maximize the similarity - optimizer = optim.Adam( - model.text_model.parameters(), lr=self.aesthetic_lr - ) - - for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"): - text_embs = model.get_text_features(input_ids=tokens) - text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) - sim = text_embs @ img_embs.T - loss = -sim - optimizer.zero_grad() - loss.mean().backward() - optimizer.step() - - zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) - if opts.CLIP_stop_at_last_layers > 1: - zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] - zn = model.text_model.final_layer_norm(zn) - else: - zn = zn.last_hidden_state - model.cpu() - del model - - zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1) - if self.slerp: - z = slerp(z, zn, self.aesthetic_weight) - else: - z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight - + z = shared.aesthetic_clip(z, remade_batch_tokens) remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 diff --git a/modules/sd_models.py b/modules/sd_models.py index 3aa21ec1..8e4ee435 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -20,7 +20,7 @@ checkpoints_loaded = collections.OrderedDict() try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. - from transformers import logging + from transformers import logging, CLIPModel logging.set_verbosity_error() except Exception: @@ -196,6 +196,9 @@ def load_model(): sd_hijack.model_hijack.hijack(sd_model) + if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path: + shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path) + sd_model.eval() print(f"Model loaded.") diff --git a/modules/shared.py b/modules/shared.py index e2c98b2d..e19ca779 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -3,6 +3,7 @@ import datetime import json import os import sys +from collections import OrderedDict import gradio as gr import tqdm @@ -94,15 +95,15 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None -aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in - os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} -aesthetic_embeddings = aesthetic_embeddings | {"None": None} +aesthetic_embeddings = {} def update_aesthetic_embeddings(): global aesthetic_embeddings aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} - aesthetic_embeddings = aesthetic_embeddings | {"None": None} + aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings) + +update_aesthetic_embeddings() def reload_hypernetworks(): global hypernetworks @@ -381,6 +382,11 @@ sd_upscalers = [] sd_model = None +clip_model = None + +from modules.aesthetic_clip import AestheticCLIP +aesthetic_clip = AestheticCLIP() + progress_print_out = sys.stdout diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 68ceffe3..23bb4b6a 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -49,7 +49,7 @@ class PersonalizedBase(Dataset): print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC) + image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) except Exception: continue diff --git a/modules/txt2img.py b/modules/txt2img.py index 8f394d05..6cbc50fc 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,12 +1,17 @@ import modules.scripts -from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \ + StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, cmd_opts import modules.shared as shared import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int,aesthetic_lr=0, +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, + restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, + subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, + height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, + firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, @@ -41,15 +46,17 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: firstphase_height=firstphase_height if enable_hr else None, ) + shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), + aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, + aesthetic_text_negative) + if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is None: - processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative) + processed = process_images(p) shared.total_tqdm.clear() @@ -61,4 +68,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: processed.images = [] return processed.images, generation_info_js, plaintext_to_html(processed.info) - diff --git a/modules/ui.py b/modules/ui.py index 4069f0d2..0e5d73f0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -43,7 +43,7 @@ from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.aesthetic_clip +import modules.aesthetic_clip as aesthetic_clip import modules.images_history as img_his @@ -593,23 +593,25 @@ def create_ui(wrap_gradio_gpu_call): width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - with gr.Group(): - with gr.Accordion("Open for Clip Aesthetic!",open=False): - with gr.Row(): - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) - - with gr.Row(): - aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") - aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) - aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), - label="Aesthetic imgs embedding", - value="None") - - with gr.Row(): - aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") - aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) - aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) + # with gr.Group(): + # with gr.Accordion("Open for Clip Aesthetic!",open=False): + # with gr.Row(): + # aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) + # aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) + # + # with gr.Row(): + # aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") + # aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) + # aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), + # label="Aesthetic imgs embedding", + # value="None") + # + # with gr.Row(): + # aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") + # aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) + # aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) + + aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui() with gr.Row(): @@ -840,6 +842,9 @@ def create_ui(wrap_gradio_gpu_call): width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui() + + with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) tiling = gr.Checkbox(label='Tiling', value=False) @@ -944,6 +949,14 @@ def create_ui(wrap_gradio_gpu_call): inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, + aesthetic_lr_im, + aesthetic_weight_im, + aesthetic_steps_im, + aesthetic_imgs_im, + aesthetic_slerp_im, + aesthetic_imgs_text_im, + aesthetic_slerp_angle_im, + aesthetic_text_negative_im, ] + custom_inputs, outputs=[ img2img_gallery, @@ -1283,7 +1296,7 @@ def create_ui(wrap_gradio_gpu_call): ) create_embedding_ae.click( - fn=modules.aesthetic_clip.generate_imgs_embd, + fn=aesthetic_clip.generate_imgs_embd, inputs=[ new_embedding_name_ae, process_src_ae, @@ -1291,6 +1304,7 @@ def create_ui(wrap_gradio_gpu_call): ], outputs=[ aesthetic_imgs, + aesthetic_imgs_im, ti_output, ti_outcome, ] -- cgit v1.2.3 From 9d702b16f01795c3af900e0ebd70faf4b25200f6 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 17 Oct 2022 16:11:03 +0800 Subject: fix two little bug --- modules/images_history.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 23045df1..1ae168ca 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -133,7 +133,7 @@ def archive_images(dir_name, date_to): date = sort_array[loads_num][2] filenames = [x[1] for x in sort_array] else: - date = sort_array[loads_num][2] + date = sort_array[-1][2] filenames = [x[1] for x in sort_array] filenames = [x[1] for x in sort_array if x[2]>= date] _, image_list, _, visible_num = get_recent_images(1, 0, filenames) @@ -334,6 +334,6 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): with gr.Tab(tab): with gr.Blocks(analytics_enabled=False) as images_history_img2img: show_images_history(gr, opts, tab, run_pnginfo, switch_dict) - gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory") #, visible=False) + gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False) return images_history -- cgit v1.2.3 From c408a0b41cfffde184cad35b2d97346342947d83 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 17 Oct 2022 22:28:43 +0800 Subject: fix two bug --- launch.py | 1 - modules/images_history.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/launch.py b/launch.py index 7520cfee..088eada1 100644 --- a/launch.py +++ b/launch.py @@ -11,7 +11,6 @@ python = sys.executable git = os.environ.get('GIT', "git") index_url = os.environ.get('INDEX_URL', "") - def extract_arg(args, name): return [x for x in args if x != name], name in args diff --git a/modules/images_history.py b/modules/images_history.py index 1ae168ca..10e5b970 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -181,7 +181,8 @@ def delete_image(delete_num, name, filenames, image_index, visible_num): return new_file_list, 1, visible_num def save_image(file_name): - shutil.copy2(file_name, opts.outdir_save) + if file_name is not None and os.path.exists(file_name): + shutil.copy2(file_name, opts.outdir_save) def get_recent_images(page_index, step, filenames): page_index = int(page_index) @@ -327,7 +328,6 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): opts = sys_opts loads_files_num = int(opts.images_history_num_per_page) num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num) - backup_flag = opts.images_history_backup with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: for tab in ["txt2img", "img2img", "extras", "saved"]: -- cgit v1.2.3 From 2272cf2f35fafd5cd486bfb4ee89df5bbc625b97 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 17 Oct 2022 23:04:42 +0800 Subject: fix two bug --- modules/images_history.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 10e5b970..1c1790a4 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -133,7 +133,7 @@ def archive_images(dir_name, date_to): date = sort_array[loads_num][2] filenames = [x[1] for x in sort_array] else: - date = sort_array[-1][2] + date = None if len(sort_array) == 0 else sort_array[-1][2] filenames = [x[1] for x in sort_array] filenames = [x[1] for x in sort_array if x[2]>= date] _, image_list, _, visible_num = get_recent_images(1, 0, filenames) -- cgit v1.2.3 From 2b5b62e768d892773a7ec1d5e8d8cea23aae1254 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 17 Oct 2022 23:14:03 +0800 Subject: fix two bug --- modules/images_history.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index 1c1790a4..20324557 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -44,7 +44,7 @@ def traverse_all_files(curr_path, image_list, all_type=False): return image_list for file in f_list: file = os.path.join(curr_path, file) - if (not all_type) and file[-4:] == ".txt": + if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"): pass elif os.path.isfile(file) and file[-10:].rfind(".") > 0: image_list.append(file) @@ -182,7 +182,7 @@ def delete_image(delete_num, name, filenames, image_index, visible_num): def save_image(file_name): if file_name is not None and os.path.exists(file_name): - shutil.copy2(file_name, opts.outdir_save) + shutil.copy(file_name, opts.outdir_save) def get_recent_images(page_index, step, filenames): page_index = int(page_index) -- cgit v1.2.3 From eb299527b1e5d1f83a14641647fca72e8fb305ac Mon Sep 17 00:00:00 2001 From: yfszzx Date: Tue, 18 Oct 2022 20:14:11 +0800 Subject: Image browser --- javascript/images_history.js | 19 ++-- modules/images_history.py | 227 ++++++++++++++++++++++++++++--------------- modules/shared.py | 7 +- modules/ui.py | 2 +- uitest.bat | 2 + uitest.py | 124 +++++++++++++++++++++++ 6 files changed, 289 insertions(+), 92 deletions(-) create mode 100644 uitest.bat create mode 100644 uitest.py (limited to 'modules') diff --git a/javascript/images_history.js b/javascript/images_history.js index 3c028bc6..182d730b 100644 --- a/javascript/images_history.js +++ b/javascript/images_history.js @@ -145,9 +145,10 @@ function images_history_enable_del_buttons(){ } function images_history_init(){ - var loaded = gradioApp().getElementById("images_history_reconstruct_directory") - if (loaded){ - var init_status = loaded.querySelector("input").checked + // var loaded = gradioApp().getElementById("images_history_reconstruct_directory") + // if (loaded){ + // var init_status = loaded.querySelector("input").checked + if (gradioApp().getElementById("images_history_finish_render")){ for (var i in images_history_tab_list ){ tab = images_history_tab_list[i]; gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor"); @@ -163,19 +164,17 @@ function images_history_init(){ for (var i in images_history_tab_list){ var tabname = images_history_tab_list[i] tab_btns[i].setAttribute("tabname", tabname); - if (init_status){ - tab_btns[i].addEventListener('click', images_history_click_tab); - } - } - if (init_status){ - tab_btns[0].click(); + // if (!init_status){ + // tab_btns[i].addEventListener('click', images_history_click_tab); + // } + tab_btns[i].addEventListener('click', images_history_click_tab); } } else { setTimeout(images_history_init, 500); } } -var images_history_tab_list = ["txt2img", "img2img", "extras", "saved"]; +var images_history_tab_list = ["custom", "txt2img", "img2img", "extras", "saved"]; setTimeout(images_history_init, 500); document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ diff --git a/modules/images_history.py b/modules/images_history.py index 20324557..d56f3a25 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -4,6 +4,7 @@ import time import hashlib import gradio system_bak_path = "webui_log_and_bak" +browser_tabname = "custom" def is_valid_date(date): try: time.strptime(date, "%Y%m%d") @@ -99,13 +100,15 @@ def auto_sorting(dir_name): date_list.append(today) return sorted(date_list, reverse=True) -def archive_images(dir_name, date_to): +def archive_images(dir_name, date_to): + filenames = [] loads_num =int(opts.images_history_num_per_page * opts.images_history_pages_num) + today = time.strftime("%Y%m%d",time.localtime(time.time())) + date_to = today if date_to is None or date_to == "" else date_to + date_to_bak = date_to if opts.images_history_reconstruct_directory: - date_list = auto_sorting(dir_name) - today = time.strftime("%Y%m%d",time.localtime(time.time())) - date_to = today if date_to is None or date_to == "" else date_to + date_list = auto_sorting(dir_name) for date in date_list: if date <= date_to: path = os.path.join(dir_name, date) @@ -120,7 +123,7 @@ def archive_images(dir_name, date_to): tmparray = [(os.path.getmtime(file), file) for file in filenames ] date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400 filenames = [] - date_list = {} + date_list = {date_to:None} date = time.strftime("%Y%m%d",time.localtime(time.time())) for t, f in tmparray: date = time.strftime("%Y%m%d",time.localtime(t)) @@ -133,22 +136,29 @@ def archive_images(dir_name, date_to): date = sort_array[loads_num][2] filenames = [x[1] for x in sort_array] else: - date = None if len(sort_array) == 0 else sort_array[-1][2] + date = date_to if len(sort_array) == 0 else sort_array[-1][2] filenames = [x[1] for x in sort_array] - filenames = [x[1] for x in sort_array if x[2]>= date] - _, image_list, _, visible_num = get_recent_images(1, 0, filenames) + filenames = [x[1] for x in sort_array if x[2]>= date] + num = len(filenames) + last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000)) + date = date[:4] + "-" + date[4:6] + "-" + date[6:8] + date_to_bak = date_to_bak[:4] + "-" + date_to_bak[4:6] + "-" + date_to_bak[6:8] + load_info = f"Loaded {(num + 1) // opts.images_history_pages_num} pades, {num} images, during {date} - {date_to_bak}" + _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames) return ( gradio.Dropdown.update(choices=date_list, value=date_to), - date, + load_info, filenames, 1, image_list, "", - visible_num + "", + visible_num, + last_date_from ) -def newest_click(dir_name, date_to): - return archive_images(dir_name, time.strftime("%Y%m%d",time.localtime(time.time()))) + + def delete_image(delete_num, name, filenames, image_index, visible_num): if name == "": @@ -196,7 +206,29 @@ def get_recent_images(page_index, step, filenames): length = len(filenames) visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num - return page_index, image_list, "", visible_num + return page_index, image_list, "", "", visible_num + +def newest_click(date_to): + if date_to is None: + return time.strftime("%Y%m%d",time.localtime(time.time())), [] + else: + return None, [] +def forward_click(last_date_from, date_to_recorder): + if len(date_to_recorder) == 0: + return None, [] + if last_date_from == date_to_recorder[-1]: + date_to_recorder = date_to_recorder[:-1] + if len(date_to_recorder) == 0: + return None, [] + return date_to_recorder[-1], date_to_recorder[:-1] + +def backward_click(last_date_from, date_to_recorder): + if last_date_from is None or last_date_from == "": + return time.strftime("%Y%m%d",time.localtime(time.time())), [] + if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]: + date_to_recorder.append(last_date_from) + return last_date_from, date_to_recorder + def first_page_click(page_index, filenames): return get_recent_images(1, 0, filenames) @@ -214,13 +246,33 @@ def page_index_change(page_index, filenames): return get_recent_images(page_index, 0, filenames) def show_image_info(tabname_box, num, page_index, filenames): - file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] - return file, num, file + file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] + tm = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + return file, tm, num, file def enable_page_buttons(): return gradio.update(visible=True) +def change_dir(img_dir, date_to): + warning = None + try: + if os.path.exists(img_dir): + try: + f = os.listdir(img_dir) + except: + warning = f"'{img_dir} is not a directory" + else: + warning = "The directory is not exist" + except: + warning = "The format of the directory is incorrect" + if warning is None: + today = time.strftime("%Y%m%d",time.localtime(time.time())) + return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today + else: + return gradio.update(visible=True), gradio.update(visible=False), warning, date_to + def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): + custom_dir = False if tabname == "txt2img": dir_name = opts.outdir_txt2img_samples elif tabname == "img2img": @@ -229,69 +281,85 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): dir_name = opts.outdir_extras_samples elif tabname == "saved": dir_name = opts.outdir_save + else: + custom_dir = True + dir_name = None + + if not custom_dir: + d = dir_name.split("/") + dir_name = d[0] + for p in d[1:]: + dir_name = os.path.join(dir_name, p) + if not os.path.exists(dir_name): + os.makedirs(dir_name) - d = dir_name.split("/") - dir_name = d[0] - for p in d[1:]: - dir_name = os.path.join(dir_name, p) - if not os.path.exists(dir_name): - os.makedirs(dir_name) - - with gr.Column() as page_panel: - with gr.Row(visible=False) as turn_page_buttons: - renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") - first_page = gr.Button('First Page') - prev_page = gr.Button('Prev Page') - page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next Page') - end_page = gr.Button('End Page') - - with gr.Row(elem_id=tabname + "_images_history"): - with gr.Column(scale=2): - with gr.Row(): - newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") - date_from = gr.Textbox(label="Date from", interactive=False) - date_to = gr.Dropdown(label="Date to") - - history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) - with gr.Row(): - delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") - delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - - with gr.Column(): - with gr.Row(): - if tabname != "saved": - save_btn = gr.Button('Save') - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') - with gr.Row(): - with gr.Column(): - img_file_info = gr.Textbox(label="Generate Info", interactive=False) - img_file_name = gr.Textbox(value="", label="File Name", interactive=False) + with gr.Column() as page_panel: + with gr.Row(): + img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory") + with gr.Row(visible=False) as warning: + warning_box = gr.Textbox("Message", interactive=False) + with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel: + with gr.Column(scale=2): + with gr.Row(): + backward = gr.Button('Backward') + date_to = gr.Dropdown(label="Date to") + forward = gr.Button('Forward') + newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") + with gr.Row(): + load_info = gr.Textbox(show_label=False, interactive=False) + with gr.Row(visible=False) as turn_page_buttons: + renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") + first_page = gr.Button('First Page') + prev_page = gr.Button('Prev Page') + page_index = gr.Number(value=1, label="Page Index") + next_page = gr.Button('Next Page') + end_page = gr.Button('End Page') + + history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num) + with gr.Row(): + delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") + delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - # hiden items - with gr.Row(visible=False): - visible_img_num = gr.Number() - img_path = gr.Textbox(dir_name) - tabname_box = gr.Textbox(tabname) - image_index = gr.Textbox(value=-1) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") - filenames = gr.State() - all_images_list = gr.State() - hidden = gr.Image(type="pil") - info1 = gr.Textbox() - info2 = gr.Textbox() + with gr.Column(): + with gr.Row(): + if tabname != "saved": + save_btn = gr.Button('Save') + pnginfo_send_to_txt2img = gr.Button('Send to txt2img') + pnginfo_send_to_img2img = gr.Button('Send to img2img') + with gr.Row(): + with gr.Column(): + img_file_info = gr.Textbox(label="Generate Info", interactive=False) + img_file_name = gr.Textbox(value="", label="File Name", interactive=False) + img_file_time= gr.Textbox(value="", label="Create Time", interactive=False) - + + # hiden items + with gr.Row(): #visible=False): + visible_img_num = gr.Number() + date_to_recorder = gr.State([]) + last_date_from = gr.Textbox() + tabname_box = gr.Textbox(tabname) + image_index = gr.Textbox(value=-1) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") + filenames = gr.State() + all_images_list = gr.State() + hidden = gr.Image(type="pil") + info1 = gr.Textbox() + info2 = gr.Textbox() + + img_path.submit(change_dir, inputs=[img_path, date_to], outputs=[warning, main_panel, warning_box, date_to]) #change date - change_date_output = [date_to, date_from, filenames, page_index, history_gallery, img_file_name, visible_img_num] - newest.click(newest_click, inputs=[img_path, date_to], outputs=change_date_output) - date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output) - newest.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) - newest.click(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) + change_date_output = [date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from] + + date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output) + date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) + date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + + newest.click(newest_click, inputs=[date_to], outputs=[date_to, date_to_recorder]) + forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder]) + backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder]) + #delete delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num]) @@ -301,7 +369,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): #turn page gallery_inputs = [page_index, filenames] - gallery_outputs = [page_index, history_gallery, img_file_name, visible_img_num] + gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num] first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs) next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs) prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) @@ -317,12 +385,14 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") # other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden]) + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden]) img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') + + def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): global opts; opts = sys_opts @@ -330,10 +400,11 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num) with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: - for tab in ["txt2img", "img2img", "extras", "saved"]: + for tab in [browser_tabname, "txt2img", "img2img", "extras", "saved"]: with gr.Tab(tab): - with gr.Blocks(analytics_enabled=False) as images_history_img2img: + with gr.Blocks(analytics_enabled=False) : show_images_history(gr, opts, tab, run_pnginfo, switch_dict) - gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False) + #gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False) + gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_finish_render", visible=False) return images_history diff --git a/modules/shared.py b/modules/shared.py index c2ea4186..1811018d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -309,10 +309,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) -options_templates.update(options_section(('images-history', "Images history"), { - "images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), +options_templates.update(options_section(('images-history', "Images Browser"), { + #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"), - "images_history_pages_num": OptionInfo(6, "Maximum number of pages per load "), + "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "), + "images_history_grid_num": OptionInfo(6, "Number of grids in each row"), })) diff --git a/modules/ui.py b/modules/ui.py index 43dc88fc..85abac4d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1548,7 +1548,7 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (images_history, "History", "images_history"), + (images_history, "Image Browser", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), (settings_interface, "Settings", "settings"), diff --git a/uitest.bat b/uitest.bat new file mode 100644 index 00000000..ae863af6 --- /dev/null +++ b/uitest.bat @@ -0,0 +1,2 @@ +venv\Scripts\python.exe uitest.py +pause diff --git a/uitest.py b/uitest.py new file mode 100644 index 00000000..393e2d81 --- /dev/null +++ b/uitest.py @@ -0,0 +1,124 @@ +import os +import threading +import time +import importlib +import signal +import threading + +from modules.paths import script_path + +from modules import devices, sd_samplers +import modules.codeformer_model as codeformer +import modules.extras +import modules.face_restoration +import modules.gfpgan_model as gfpgan +import modules.img2img + +import modules.lowvram +import modules.paths +import modules.scripts +import modules.sd_hijack +import modules.sd_models +import modules.shared as shared +import modules.txt2img + +import modules.ui +from modules import devices +from modules import modelloader +from modules.paths import script_path +from modules.shared import cmd_opts + +modelloader.cleanup_models() +modules.sd_models.setup_model() +codeformer.setup_model(cmd_opts.codeformer_models_path) +gfpgan.setup_model(cmd_opts.gfpgan_models_path) +shared.face_restorers.append(modules.face_restoration.FaceRestoration()) +modelloader.load_upscalers() +queue_lock = threading.Lock() + + +def wrap_queued_call(func): + def f(*args, **kwargs): + with queue_lock: + res = func(*args, **kwargs) + + return res + + return f + + +def wrap_gradio_gpu_call(func, extra_outputs=None): + def f(*args, **kwargs): + devices.torch_gc() + + shared.state.sampling_step = 0 + shared.state.job_count = -1 + shared.state.job_no = 0 + shared.state.job_timestamp = shared.state.get_job_timestamp() + shared.state.current_latent = None + shared.state.current_image = None + shared.state.current_image_sampling_step = 0 + shared.state.interrupted = False + shared.state.textinfo = None + + with queue_lock: + res = func(*args, **kwargs) + + shared.state.job = "" + shared.state.job_count = 0 + + devices.torch_gc() + + return res + + return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) + + +modules.scripts.load_scripts(os.path.join(script_path, "scripts")) + +shared.sd_model = None #modules.sd_models.load_model() +#shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) + + +def webui(): + # make the program just exit at ctrl+c without waiting for anything + def sigint_handler(sig, frame): + print(f'Interrupted with signal {sig} in {frame}') + os._exit(0) + + signal.signal(signal.SIGINT, sigint_handler) + + while 1: + + demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) + + demo.launch( + share=cmd_opts.share, + server_name="0.0.0.0" if cmd_opts.listen else None, + server_port=cmd_opts.port, + debug=cmd_opts.gradio_debug, + auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, + inbrowser=cmd_opts.autolaunch, + prevent_thread_lock=True + ) + + while 1: + time.sleep(0.5) + if getattr(demo, 'do_restart', False): + time.sleep(0.5) + demo.close() + time.sleep(0.5) + break + + sd_samplers.set_samplers() + + print('Reloading Custom Scripts') + modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) + print('Reloading modules: modules.ui') + importlib.reload(modules.ui) + print('Restarting Gradio') + + + +if __name__ == "__main__": + webui() \ No newline at end of file -- cgit v1.2.3 From b7e78ef692fe912916de6e54f6e2521b000d650c Mon Sep 17 00:00:00 2001 From: yfszzx Date: Tue, 18 Oct 2022 22:21:54 +0800 Subject: Image browser improve --- modules/images_history.py | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) (limited to 'modules') diff --git a/modules/images_history.py b/modules/images_history.py index d56f3a25..a40cdc0e 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -100,14 +100,15 @@ def auto_sorting(dir_name): date_list.append(today) return sorted(date_list, reverse=True) -def archive_images(dir_name, date_to): - +def archive_images(dir_name, date_to): filenames = [] - loads_num =int(opts.images_history_num_per_page * opts.images_history_pages_num) + batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num) + if batch_size <= 0: + batch_size = opts.images_history_num_per_page * 6 today = time.strftime("%Y%m%d",time.localtime(time.time())) date_to = today if date_to is None or date_to == "" else date_to date_to_bak = date_to - if opts.images_history_reconstruct_directory: + if False: #opts.images_history_reconstruct_directory: date_list = auto_sorting(dir_name) for date in date_list: if date <= date_to: @@ -115,11 +116,13 @@ def archive_images(dir_name, date_to): if date == today and not os.path.exists(path): continue filenames = traverse_all_files(path, filenames) - if len(filenames) > loads_num: + if len(filenames) > batch_size: break filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file)) else: - filenames = traverse_all_files(dir_name, filenames) + filenames = traverse_all_files(dir_name, filenames) + total_num = len(filenames) + batch_count = len(filenames) + 1 // batch_size + 1 tmparray = [(os.path.getmtime(file), file) for file in filenames ] date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400 filenames = [] @@ -132,8 +135,8 @@ def archive_images(dir_name, date_to): filenames.append((t, f ,date)) date_list = sorted(list(date_list.keys()), reverse=True) sort_array = sorted(filenames, key=lambda x:-x[0]) - if len(sort_array) > loads_num: - date = sort_array[loads_num][2] + if len(sort_array) > batch_size: + date = sort_array[batch_size][2] filenames = [x[1] for x in sort_array] else: date = date_to if len(sort_array) == 0 else sort_array[-1][2] @@ -141,9 +144,9 @@ def archive_images(dir_name, date_to): filenames = [x[1] for x in sort_array if x[2]>= date] num = len(filenames) last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000)) - date = date[:4] + "-" + date[4:6] + "-" + date[6:8] - date_to_bak = date_to_bak[:4] + "-" + date_to_bak[4:6] + "-" + date_to_bak[6:8] - load_info = f"Loaded {(num + 1) // opts.images_history_pages_num} pades, {num} images, during {date} - {date_to_bak}" + date = date[:4] + "/" + date[4:6] + "/" + date[6:8] + date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8] + load_info = f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages" _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames) return ( gradio.Dropdown.update(choices=date_list, value=date_to), @@ -154,12 +157,10 @@ def archive_images(dir_name, date_to): "", "", visible_num, - last_date_from + last_date_from, + #gradio.update(visible=batch_count > 1) ) - - - def delete_image(delete_num, name, filenames, image_index, visible_num): if name == "": return filenames, delete_num @@ -295,16 +296,16 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): with gr.Column() as page_panel: with gr.Row(): - img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory") + img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir) with gr.Row(visible=False) as warning: warning_box = gr.Textbox("Message", interactive=False) with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel: with gr.Column(scale=2): - with gr.Row(): - backward = gr.Button('Backward') - date_to = gr.Dropdown(label="Date to") - forward = gr.Button('Forward') + with gr.Row() as batch_panel: + forward = gr.Button('Forward') + date_to = gr.Dropdown(label="Date to") + backward = gr.Button('Backward') newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") with gr.Row(): load_info = gr.Textbox(show_label=False, interactive=False) @@ -335,7 +336,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): # hiden items - with gr.Row(): #visible=False): + with gr.Row(visible=False): visible_img_num = gr.Number() date_to_recorder = gr.State([]) last_date_from = gr.Textbox() -- cgit v1.2.3 From 538bc89c269743e56b07ef2b471d1ce0a39b6776 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Wed, 19 Oct 2022 11:27:51 +0800 Subject: Image browser improved --- javascript/images_history.js | 87 ++++++++++++++-------------- modules/images_history.py | 135 ++++++++++++++++++++++++------------------- modules/shared.py | 5 ++ modules/ui.py | 2 +- 4 files changed, 123 insertions(+), 106 deletions(-) (limited to 'modules') diff --git a/javascript/images_history.js b/javascript/images_history.js index 182d730b..c9aa76f8 100644 --- a/javascript/images_history.js +++ b/javascript/images_history.js @@ -17,14 +17,6 @@ var images_history_click_image = function(){ images_history_set_image_info(this); } -var images_history_click_tab = function(){ - var tabs_box = gradioApp().getElementById("images_history_tab"); - if (!tabs_box.classList.contains(this.getAttribute("tabname"))) { - gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_start").click(); - tabs_box.classList.add(this.getAttribute("tabname")) - } -} - function images_history_disabled_del(){ gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){ btn.setAttribute('disabled','disabled'); @@ -145,57 +137,64 @@ function images_history_enable_del_buttons(){ } function images_history_init(){ - // var loaded = gradioApp().getElementById("images_history_reconstruct_directory") - // if (loaded){ - // var init_status = loaded.querySelector("input").checked - if (gradioApp().getElementById("images_history_finish_render")){ + var tabnames = gradioApp().getElementById("images_history_tabnames_list") + if (tabnames){ + images_history_tab_list = tabnames.querySelector("textarea").value.split(",") for (var i in images_history_tab_list ){ - tab = images_history_tab_list[i]; + var tab = images_history_tab_list[i]; gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor"); gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index"); gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button"); - gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery"); - + gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery"); + gradioApp().getElementById(tab + "_images_history_start").setAttribute("style","padding:20px;font-size:25px"); + } + + //preload + if (gradioApp().getElementById("images_history_preload").querySelector("input").checked ){ + var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div"); + tabs_box.setAttribute("id", "images_history_tab"); + var tab_btns = tabs_box.querySelectorAll("button"); + for (var i in images_history_tab_list){ + var tabname = images_history_tab_list[i] + tab_btns[i].setAttribute("tabname", tabname); + tab_btns[i].addEventListener('click', function(){ + var tabs_box = gradioApp().getElementById("images_history_tab"); + if (!tabs_box.classList.contains(this.getAttribute("tabname"))) { + gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_start").click(); + tabs_box.classList.add(this.getAttribute("tabname")) + } + }); + } + tab_btns[0].click() } - var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div"); - tabs_box.setAttribute("id", "images_history_tab"); - var tab_btns = tabs_box.querySelectorAll("button"); - - for (var i in images_history_tab_list){ - var tabname = images_history_tab_list[i] - tab_btns[i].setAttribute("tabname", tabname); - // if (!init_status){ - // tab_btns[i].addEventListener('click', images_history_click_tab); - // } - tab_btns[i].addEventListener('click', images_history_click_tab); - } } else { setTimeout(images_history_init, 500); } } -var images_history_tab_list = ["custom", "txt2img", "img2img", "extras", "saved"]; +var images_history_tab_list = ""; setTimeout(images_history_init, 500); document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ - for (var i in images_history_tab_list ){ - let tabname = images_history_tab_list[i] - var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item'); - buttons.forEach(function(bnt){ - bnt.addEventListener('click', images_history_click_image, true); - }); - - var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg"); - if (cls_btn){ - cls_btn.addEventListener('click', function(){ - gradioApp().getElementById(tabname + '_images_history_del_button').setAttribute('disabled','disabled'); - }, false); - } - - } + if (images_history_tab_list != ""){ + for (var i in images_history_tab_list ){ + let tabname = images_history_tab_list[i] + var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item'); + buttons.forEach(function(bnt){ + bnt.addEventListener('click', images_history_click_image, true); + }); + + var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg"); + if (cls_btn){ + cls_btn.addEventListener('click', function(){ + gradioApp().getElementById(tabname + '_images_history_renew_page').click(); + }, false); + } + + } + } }); mutationObserver.observe(gradioApp(), { childList:true, subtree:true }); - }); diff --git a/modules/images_history.py b/modules/images_history.py index a40cdc0e..78fd0543 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -4,7 +4,9 @@ import time import hashlib import gradio system_bak_path = "webui_log_and_bak" -browser_tabname = "custom" +custom_tab_name = "custom fold" +faverate_tab_name = "favorites" +tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name] def is_valid_date(date): try: time.strptime(date, "%Y%m%d") @@ -122,7 +124,6 @@ def archive_images(dir_name, date_to): else: filenames = traverse_all_files(dir_name, filenames) total_num = len(filenames) - batch_count = len(filenames) + 1 // batch_size + 1 tmparray = [(os.path.getmtime(file), file) for file in filenames ] date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400 filenames = [] @@ -146,10 +147,12 @@ def archive_images(dir_name, date_to): last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000)) date = date[:4] + "/" + date[4:6] + "/" + date[6:8] date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8] - load_info = f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages" + load_info = "
" + load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages" + load_info += "
" _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames) return ( - gradio.Dropdown.update(choices=date_list, value=date_to), + date_to, load_info, filenames, 1, @@ -158,7 +161,7 @@ def archive_images(dir_name, date_to): "", visible_num, last_date_from, - #gradio.update(visible=batch_count > 1) + gradio.update(visible=total_num > num) ) def delete_image(delete_num, name, filenames, image_index, visible_num): @@ -209,7 +212,7 @@ def get_recent_images(page_index, step, filenames): visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num return page_index, image_list, "", "", visible_num -def newest_click(date_to): +def loac_batch_click(date_to): if date_to is None: return time.strftime("%Y%m%d",time.localtime(time.time())), [] else: @@ -248,7 +251,7 @@ def page_index_change(page_index, filenames): def show_image_info(tabname_box, num, page_index, filenames): file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] - tm = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + tm = "
" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "
" return file, tm, num, file def enable_page_buttons(): @@ -268,9 +271,9 @@ def change_dir(img_dir, date_to): warning = "The format of the directory is incorrect" if warning is None: today = time.strftime("%Y%m%d",time.localtime(time.time())) - return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today + return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True) else: - return gradio.update(visible=True), gradio.update(visible=False), warning, date_to + return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False) def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): custom_dir = False @@ -280,7 +283,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): dir_name = opts.outdir_img2img_samples elif tabname == "extras": dir_name = opts.outdir_extras_samples - elif tabname == "saved": + elif tabname == faverate_tab_name: dir_name = opts.outdir_save else: custom_dir = True @@ -295,22 +298,26 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): os.makedirs(dir_name) with gr.Column() as page_panel: - with gr.Row(): - img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir) + with gr.Row(): + with gr.Column(scale=1, visible=not custom_dir) as load_batch_box: + load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True) + with gr.Column(scale=4): + with gr.Row(): + img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir) + with gr.Row(): + with gr.Column(visible=False, scale=1) as batch_panel: + with gr.Row(): + forward = gr.Button('Prev batch') + backward = gr.Button('Next batch') + with gr.Column(scale=3): + load_info = gr.HTML(visible=not custom_dir) with gr.Row(visible=False) as warning: warning_box = gr.Textbox("Message", interactive=False) with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel: - with gr.Column(scale=2): - with gr.Row() as batch_panel: - forward = gr.Button('Forward') - date_to = gr.Dropdown(label="Date to") - backward = gr.Button('Backward') - newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") - with gr.Row(): - load_info = gr.Textbox(show_label=False, interactive=False) - with gr.Row(visible=False) as turn_page_buttons: - renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") + with gr.Column(scale=2): + with gr.Row(visible=True) as turn_page_buttons: + #date_to = gr.Dropdown(label="Date to") first_page = gr.Button('First Page') prev_page = gr.Button('Prev Page') page_index = gr.Number(value=1, label="Page Index") @@ -322,50 +329,54 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - with gr.Column(): - with gr.Row(): - if tabname != "saved": - save_btn = gr.Button('Save') - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') + with gr.Column(): with gr.Row(): with gr.Column(): - img_file_info = gr.Textbox(label="Generate Info", interactive=False) + img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6) + gr.HTML("
") img_file_name = gr.Textbox(value="", label="File Name", interactive=False) - img_file_time= gr.Textbox(value="", label="Create Time", interactive=False) - + img_file_time= gr.HTML() + with gr.Row(): + if tabname != faverate_tab_name: + save_btn = gr.Button('Collect') + pnginfo_send_to_txt2img = gr.Button('Send to txt2img') + pnginfo_send_to_img2img = gr.Button('Send to img2img') + - # hiden items - with gr.Row(visible=False): - visible_img_num = gr.Number() - date_to_recorder = gr.State([]) - last_date_from = gr.Textbox() - tabname_box = gr.Textbox(tabname) - image_index = gr.Textbox(value=-1) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") - filenames = gr.State() - all_images_list = gr.State() - hidden = gr.Image(type="pil") - info1 = gr.Textbox() - info2 = gr.Textbox() - - img_path.submit(change_dir, inputs=[img_path, date_to], outputs=[warning, main_panel, warning_box, date_to]) - #change date - change_date_output = [date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from] + # hiden items + with gr.Row(visible=False): + renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") + batch_date_to = gr.Textbox(label="Date to") + visible_img_num = gr.Number() + date_to_recorder = gr.State([]) + last_date_from = gr.Textbox() + tabname_box = gr.Textbox(tabname) + image_index = gr.Textbox(value=-1) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") + filenames = gr.State() + all_images_list = gr.State() + hidden = gr.Image(type="pil") + info1 = gr.Textbox() + info2 = gr.Textbox() + + img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info]) + + #change batch + change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel] - date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output) - date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) - date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output) + batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) + batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - newest.click(newest_click, inputs=[date_to], outputs=[date_to, date_to_recorder]) - forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder]) - backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder]) + load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder]) + forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) + backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) #delete delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num]) delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None) - if tabname != "saved": + if tabname != faverate_tab_name: save_btn.click(save_image, inputs=[img_file_name], outputs=None) #turn page @@ -394,18 +405,20 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): -def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): +def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict): global opts; opts = sys_opts loads_files_num = int(opts.images_history_num_per_page) num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num) + if cmp_ops.browse_all_images: + tabs_list.append(custom_tab_name) with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: - for tab in [browser_tabname, "txt2img", "img2img", "extras", "saved"]: + for tab in tabs_list: with gr.Tab(tab): with gr.Blocks(analytics_enabled=False) : - show_images_history(gr, opts, tab, run_pnginfo, switch_dict) - #gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False) - gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_finish_render", visible=False) - + show_images_history(gr, opts, tab, run_pnginfo, switch_dict) + gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False) + gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False) + return images_history diff --git a/modules/shared.py b/modules/shared.py index 1811018d..4d735414 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,6 +74,10 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) +parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False) + + +cmd_opts = parser.parse_args() cmd_opts = parser.parse_args() @@ -311,6 +315,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" options_templates.update(options_section(('images-history', "Images Browser"), { #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), + "images_history_preload": OptionInfo(False, "Preload images at startup"), "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"), "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "), "images_history_grid_num": OptionInfo(6, "Number of grids in each row"), diff --git a/modules/ui.py b/modules/ui.py index 85abac4d..88f46659 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1150,7 +1150,7 @@ def create_ui(wrap_gradio_gpu_call): "i2i":img2img_paste_fields } - images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) + images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): -- cgit v1.2.3 From abeec4b63029c2c4151a78fc395d312113881845 Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 03:18:26 -0700 Subject: Add auto focal point cropping to Preprocess images This algorithm plots a bunch of points of interest on the source image and averages their locations to find a center. Most points come from OpenCV. One point comes from an entropy model. OpenCV points account for 50% of the weight and the entropy based point is the other 50%. The center of all weighted points is calculated and a bounding box is drawn as close to centered over that point as possible. --- modules/textual_inversion/preprocess.py | 151 ++++++++++++++++++++++++++++++-- 1 file changed, 146 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 886cf0c3..168bfb09 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,5 +1,7 @@ import os -from PIL import Image, ImageOps +import cv2 +import numpy as np +from PIL import Image, ImageOps, ImageDraw import platform import sys import tqdm @@ -11,7 +13,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False): try: if process_caption: shared.interrogator.load() @@ -21,7 +23,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, process_entropy_focus) finally: @@ -33,7 +35,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ -def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False): width = process_width height = process_height src = os.path.abspath(process_src) @@ -93,6 +95,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro is_tall = ratio > 1.35 is_wide = ratio < 1 / 1.35 + processing_option_ran = False + if process_split and is_tall: img = img.resize((width, height * img.height // img.width)) @@ -101,6 +105,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro bot = img.crop((0, img.height - height, width, img.height)) save_pic(bot, index) + + processing_option_ran = True elif process_split and is_wide: img = img.resize((width * img.width // img.height, height)) @@ -109,8 +115,143 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro right = img.crop((img.width - width, 0, img.width, height)) save_pic(right, index) - else: + + processing_option_ran = True + + if process_entropy_focus and (is_tall or is_wide): + if is_tall: + img = img.resize((width, height * img.height // img.width)) + else: + img = img.resize((width * img.width // img.height, height)) + + x_focal_center, y_focal_center = image_central_focal_point(img, width, height) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(height / 2) + x_half = int(width / 2) + + x1 = x_focal_center - x_half + if x1 < 0: + x1 = 0 + elif x1 + width > img.width: + x1 = img.width - width + + y1 = y_focal_center - y_half + if y1 < 0: + y1 = 0 + elif y1 + height > img.height: + y1 = img.height - height + + x2 = x1 + width + y2 = y1 + height + + crop = [x1, y1, x2, y2] + + focal = img.crop(tuple(crop)) + save_pic(focal, index) + + processing_option_ran = True + + if not processing_option_ran: img = images.resize_image(1, img, width, height) save_pic(img, index) shared.state.nextjob() + + +def image_central_focal_point(im, target_width, target_height): + focal_points = [] + + focal_points.extend( + image_focal_points(im) + ) + + fp_entropy = image_entropy_point(im, target_width, target_height) + fp_entropy['weight'] = len(focal_points) + 1 # about half of the weight to entropy + + focal_points.append(fp_entropy) + + weight = 0.0 + x = 0.0 + y = 0.0 + for focal_point in focal_points: + weight += focal_point['weight'] + x += focal_point['x'] * focal_point['weight'] + y += focal_point['y'] * focal_point['weight'] + avg_x = round(x // weight) + avg_y = round(y // weight) + + return avg_x, avg_y + + +def image_focal_points(im): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=50, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.05, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append({ + 'x': x, + 'y': y, + 'weight': 1.0 + }) + + return focal_points + + +def image_entropy_point(im, crop_width, crop_height): + img = im.copy() + # just make it easier to slide the test crop with images oriented the same way + if (img.size[0] < img.size[1]): + portrait = True + img = img.rotate(90, expand=1) + + e_max = 0 + crop_current = [0, 0, crop_width, crop_height] + crop_best = crop_current + while crop_current[2] < img.size[0]: + crop = img.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e_max < e): + e_max = e + crop_best = list(crop_current) + + crop_current[0] += 4 + crop_current[2] += 4 + + x_mid = int((crop_best[2] - crop_best[0])/2) + y_mid = int((crop_best[3] - crop_best[1])/2) + + return { + 'x': x_mid, + 'y': y_mid, + 'weight': 1.0 + } + + +def image_entropy(im): + # greyscale image entropy + band = np.asarray(im.convert("L")) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + -- cgit v1.2.3 From 087609ee181a91a523647435ffffa6288a317e2f Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 03:19:35 -0700 Subject: UI changes for focal point image cropping --- modules/ui.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 1ff7eb4f..b6be713b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1234,6 +1234,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') process_split = gr.Checkbox(label='Split oversized images into two') + process_entropy_focus = gr.Checkbox(label='Create auto focal point crop') process_caption = gr.Checkbox(label='Use BLIP for caption') process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) @@ -1318,7 +1319,8 @@ def create_ui(wrap_gradio_gpu_call): process_flip, process_split, process_caption, - process_caption_deepbooru + process_caption_deepbooru, + process_entropy_focus ], outputs=[ ti_output, -- cgit v1.2.3 From 019a3a88f07766f2d32c32fbe8e41625f28ecb5e Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 19 Oct 2022 17:15:47 +0100 Subject: Update ui.py --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index d2e24880..1573ef82 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1247,7 +1247,7 @@ def create_ui(wrap_gradio_gpu_call): run_preprocess = gr.Button(value="Preprocess", variant='primary') with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") + gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images
Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork wiki

") with gr.Row(): train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") -- cgit v1.2.3 From 2ce52d32e41fb523d1494f45073fd18496e52d35 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Wed, 19 Oct 2022 16:31:12 +0000 Subject: fix for #3086 failing to load any previous hypernet --- modules/hypernetworks/hypernetwork.py | 60 ++++++++++++++++------------------- 1 file changed, 28 insertions(+), 32 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d519cd9..74300122 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -24,11 +24,10 @@ class HypernetworkModule(torch.nn.Module): def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False): super().__init__() - if layer_structure is not None: - assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" - assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - else: - layer_structure = parse_layer_structure(dim, state_dict) + + assert layer_structure is not None, "layer_structure mut not be None" + assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" + assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" linears = [] for i in range(len(layer_structure) - 1): @@ -39,23 +38,30 @@ class HypernetworkModule(torch.nn.Module): self.linear = torch.nn.Sequential(*linears) if state_dict is not None: - try: - self.load_state_dict(state_dict) - except RuntimeError: - self.try_load_previous(state_dict) + self.fix_old_state_dict(state_dict) + self.load_state_dict(state_dict) else: for layer in self.linear: - layer.weight.data.normal_(mean = 0.0, std = 0.01) + layer.weight.data.normal_(mean=0.0, std=0.01) layer.bias.data.zero_() self.to(devices.device) - def try_load_previous(self, state_dict): - states = self.state_dict() - states['linear.0.bias'].copy_(state_dict['linear1.bias']) - states['linear.0.weight'].copy_(state_dict['linear1.weight']) - states['linear.1.bias'].copy_(state_dict['linear2.bias']) - states['linear.1.weight'].copy_(state_dict['linear2.weight']) + def fix_old_state_dict(self, state_dict): + changes = { + 'linear1.bias': 'linear.0.bias', + 'linear1.weight': 'linear.0.weight', + 'linear2.bias': 'linear.1.bias', + 'linear2.weight': 'linear.1.weight', + } + + for fr, to in changes.items(): + x = state_dict.get(fr, None) + if x is None: + continue + + del state_dict[fr] + state_dict[to] = x def forward(self, x): return x + self.linear(x) * self.multiplier @@ -71,18 +77,6 @@ def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength -def parse_layer_structure(dim, state_dict): - i = 0 - layer_structure = [1] - - while (key := "linear.{}.weight".format(i)) in state_dict: - weight = state_dict[key] - layer_structure.append(len(weight) // dim) - i += 1 - - return layer_structure - - class Hypernetwork: filename = None name = None @@ -135,17 +129,18 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu') + self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) + self.add_layer_norm = state_dict.get('is_layer_norm', False) + for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]), - HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]), + HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm), ) self.name = state_dict.get('name', self.name) self.step = state_dict.get('step', 0) - self.layer_structure = state_dict.get('layer_structure', None) - self.add_layer_norm = state_dict.get('is_layer_norm', False) self.sd_checkpoint = state_dict.get('sd_checkpoint', None) self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) @@ -244,6 +239,7 @@ def stack_conds(conds): return torch.stack(conds) + def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert hypernetwork_name, 'hypernetwork not selected' -- cgit v1.2.3 From 14c1c2b9351f16d43ba4e6b6c9062edad44a6bec Mon Sep 17 00:00:00 2001 From: Alexandre Simard Date: Wed, 19 Oct 2022 13:53:52 -0400 Subject: Show PB texts at same time and earlier For big tasks (1000+ steps), waiting 1 minute to see ETA is long and this changes it so the number of steps done plays a role in showing the text as well. --- modules/ui.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index a2dbd41e..0abd177a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -261,14 +261,14 @@ def wrap_gradio_call(func, extra_outputs=None): return f -def calc_time_left(progress, threshold, label, force_display): +def calc_time_left(progress, threshold, label, force_display, showTime): if progress == 0: return "" else: time_since_start = time.time() - shared.state.time_start eta = (time_since_start/progress) eta_relative = eta-time_since_start - if (eta_relative > threshold and progress > 0.02) or force_display: + if (eta_relative > threshold and showTime) or force_display: if eta_relative > 3600: return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) elif eta_relative > 60: @@ -290,7 +290,10 @@ def check_progress_call(id_part): if shared.state.sampling_steps > 0: progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display ) + # Show progress percentage and time left at the same moment, and base it also on steps done + showPBText = progress >= 0.01 or shared.state.sampling_step >= 10 + + time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display, showPBText ) if time_left != "": shared.state.time_left_force_display = True @@ -298,7 +301,7 @@ def check_progress_call(id_part): progressbar = "" if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}
""" + progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if showPBText else ""}
""" image = gr_show(False) preview_visibility = gr_show(False) -- cgit v1.2.3 From eb7ba4b713ac2fb960ecf6365b1de0c89451e583 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 19 Oct 2022 19:50:46 +0100 Subject: update training header text --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 1573ef82..93c0767c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1247,7 +1247,7 @@ def create_ui(wrap_gradio_gpu_call): run_preprocess = gr.Button(value="Preprocess", variant='primary') with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images
Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork wiki

") + gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images
Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork [wiki]

") with gr.Row(): train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") -- cgit v1.2.3 From 4d663055ded968831ec97f047dfa8e94036cf1c1 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 19 Oct 2022 20:33:18 +0100 Subject: update ui with extra training options --- modules/ui.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 93c0767c..cdb9d335 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1206,6 +1206,7 @@ def create_ui(wrap_gradio_gpu_call): new_embedding_name = gr.Textbox(label="Name") initialization_text = gr.Textbox(label="Initialization text", value="*") nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") with gr.Row(): with gr.Column(scale=3): @@ -1219,6 +1220,7 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") with gr.Row(): with gr.Column(scale=3): @@ -1247,14 +1249,17 @@ def create_ui(wrap_gradio_gpu_call): run_preprocess = gr.Button(value="Preprocess", variant='primary') with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images
Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork [wiki]

") + gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") with gr.Row(): train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") with gr.Row(): train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") + with gr.Row(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") + batch_size = gr.Number(label='Batch size', value=1, precision=0) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") @@ -1288,6 +1293,7 @@ def create_ui(wrap_gradio_gpu_call): new_embedding_name, initialization_text, nvpt, + overwrite_old_embedding, ], outputs=[ train_embedding_name, @@ -1303,6 +1309,7 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_sizes, new_hypernetwork_layer_structure, new_hypernetwork_add_layer_norm, + overwrite_old_hypernetwork, ], outputs=[ train_hypernetwork_name, -- cgit v1.2.3 From 41e3877be2c667316515c86037413763eb0ba4da Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 13:44:59 -0700 Subject: fix entropy point calculation --- modules/textual_inversion/preprocess.py | 34 ++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 168bfb09..7c1a594e 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -196,9 +196,9 @@ def image_focal_points(im): points = cv2.goodFeaturesToTrack( np_im, - maxCorners=50, + maxCorners=100, qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.05, + minDistance=min(grayscale.width, grayscale.height)*0.07, useHarrisDetector=False, ) @@ -218,28 +218,32 @@ def image_focal_points(im): def image_entropy_point(im, crop_width, crop_height): - img = im.copy() - # just make it easier to slide the test crop with images oriented the same way - if (img.size[0] < img.size[1]): - portrait = True - img = img.rotate(90, expand=1) + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] e_max = 0 crop_current = [0, 0, crop_width, crop_height] crop_best = crop_current - while crop_current[2] < img.size[0]: - crop = img.crop(tuple(crop_current)) + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) e = image_entropy(crop) - if (e_max < e): + if (e > e_max): e_max = e crop_best = list(crop_current) - crop_current[0] += 4 - crop_current[2] += 4 + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + crop_width/2) + y_mid = int(crop_best[1] + crop_height/2) - x_mid = int((crop_best[2] - crop_best[0])/2) - y_mid = int((crop_best[3] - crop_best[1])/2) return { 'x': x_mid, @@ -250,7 +254,7 @@ def image_entropy_point(im, crop_width, crop_height): def image_entropy(im): # greyscale image entropy - band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1")) hist, _ = np.histogram(band, bins=range(0, 256)) hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() -- cgit v1.2.3 From 8e7097d06a6a261580d34375c9d2a9e4ffc63ffa Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Wed, 19 Oct 2022 13:47:45 -0700 Subject: Added support for RunwayML inpainting model --- modules/processing.py | 34 ++++++- modules/sd_hijack_inpainting.py | 208 ++++++++++++++++++++++++++++++++++++++++ modules/sd_models.py | 16 +++- modules/sd_samplers.py | 50 +++++++--- 4 files changed, 293 insertions(+), 15 deletions(-) create mode 100644 modules/sd_hijack_inpainting.py (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index bcb0c32c..a6c308f9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -546,7 +546,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) + + # The "masked-image" in this case will just be all zeros since the entire image is masked. + image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device) + image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=image_conditioning) return samples x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) @@ -714,10 +723,31 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask + if self.image_mask is not None: + conditioning_mask = np.array(self.image_mask.convert("L")) + conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 + conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: + conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) + + # Create another latent image, this time with a masked version of the original input. + conditioning_mask = conditioning_mask.to(image.device) + conditioning_image = image * (1.0 - conditioning_mask) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + + # Create the concatenated conditioning tensor to be fed to `c_concat` + conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) + conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) + self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) + self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning) + samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py new file mode 100644 index 00000000..7e5670d6 --- /dev/null +++ b/modules/sd_hijack_inpainting.py @@ -0,0 +1,208 @@ +import torch +import numpy as np + +from tqdm import tqdm +from einops import rearrange, repeat +from omegaconf import ListConfig + +from types import MethodType + +import ldm.models.diffusion.ddpm +import ldm.models.diffusion.ddim + +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.models.diffusion.ddim import DDIMSampler, noise_like + +# ================================================================================================= +# Monkey patch DDIMSampler methods from RunwayML repo directly. +# Adapted from: +# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py +# ================================================================================================= +@torch.no_grad() +def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = elf.inpainting_fill == 2: + self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask + elif self.inpainting_fill == 3: + self.init_latent = self.init_latent * self.mask + + if self.image_mask is not None: + conditioning_mask = np.array(self.image_mask.convert("L")) + conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 + conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: + conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) + + # Create another latent image, this time with a masked version of the original input. + conditioning_mask = conditioning_mask.to(image.device) + conditioning_image = image * (1.0 - conditioning_mask) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + + # Create the concatenated conditioning tensor to be fed to `c_concat` + conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) + conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) + self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) + self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) + + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): + x = create_random_tensors([opctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + +@torch.no_grad() +def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + +# ================================================================================================= +# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config. +# Adapted from: +# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py +# ================================================================================================= + +@torch.no_grad() +def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + # todo: get null label from cond_stage_model + raise NotImplementedError() + c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) + return c + +class LatentInpaintDiffusion(LatentDiffusion): + def __init__( + self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + self.concat_keys = concat_keys + +def should_hijack_inpainting(checkpoint_info): + return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml") + +def do_inpainting_hijack(): + ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning + ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion + ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim + ldm.models.diffusion.ddim.DDIMSampler.sample = sample \ No newline at end of file diff --git a/modules/sd_models.py b/modules/sd_models.py index eae22e87..47836d25 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config from modules import shared, modelloader, devices from modules.paths import models_path +from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -211,6 +212,19 @@ def load_model(): print(f"Loading config from: {checkpoint_info.config}") sd_config = OmegaConf.load(checkpoint_info.config) + + if should_hijack_inpainting(checkpoint_info): + do_inpainting_hijack() + + # Hardcoded config for now... + sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" + sd_config.model.params.use_ema = False + sd_config.model.params.conditioning_key = "hybrid" + sd_config.model.params.unet_config.params.in_channels = 9 + + # Create a "fake" config with a different name so that we know to unload it when switching models. + checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) @@ -234,7 +248,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if sd_model.sd_checkpoint_info.config != checkpoint_info.config: + if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() shared.sd_model = load_model() return shared.sd_model diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index b58e810b..9d3cf289 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -136,9 +136,15 @@ class VanillaStableDiffusionSampler: if self.stop_at is not None and self.step > self.stop_at: raise InterruptedException + # Have to unwrap the inpainting conditioning here to perform pre-preocessing + image_conditioning = None + if isinstance(cond, dict): + image_conditioning = cond["c_concat"][0] + cond = cond["c_crossattn"][0] + unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) + unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' cond = tensor @@ -157,6 +163,10 @@ class VanillaStableDiffusionSampler: img_orig = self.sampler.model.q_sample(self.init_latent, ts) x_dec = img_orig * self.mask + self.nmask * x_dec + if image_conditioning is not None: + cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) if self.mask is not None: @@ -182,7 +192,7 @@ class VanillaStableDiffusionSampler: self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = setup_img2img_steps(p, steps) self.initialize(p) @@ -202,7 +212,7 @@ class VanillaStableDiffusionSampler: return samples - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): self.initialize(p) self.init_latent = None @@ -210,6 +220,11 @@ class VanillaStableDiffusionSampler: steps = steps or p.steps + # Wrap the conditioning models with additional image conditioning for inpainting model + if image_conditioning is not None: + conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + # existing code fails with certain step counts, like 9 try: samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) @@ -228,7 +243,7 @@ class CFGDenoiser(torch.nn.Module): self.init_latent = None self.step = 0 - def forward(self, x, sigma, uncond, cond, cond_scale): + def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): if state.interrupted or state.skipped: raise InterruptedException @@ -239,28 +254,29 @@ class CFGDenoiser(torch.nn.Module): repeats = [len(conds_list[i]) for i in range(batch_size)] x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) if tensor.shape[1] == uncond.shape[1]: cond_in = torch.cat([tensor, uncond]) if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=cond_in) + x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) else: x_out = torch.zeros_like(x_in) for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b]) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size for batch_offset in range(0, tensor.shape[0], batch_size): a = batch_offset b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b]) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond) + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) denoised_uncond = x_out[-uncond.shape[0]:] denoised = torch.clone(denoised_uncond) @@ -361,7 +377,7 @@ class KDiffusionSampler: return extra_params_kwargs - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = setup_img2img_steps(p, steps) if p.sampler_noise_scheduler_override: @@ -389,11 +405,16 @@ class KDiffusionSampler: self.model_wrap_cfg.init_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): steps = steps or p.steps if p.sampler_noise_scheduler_override: @@ -414,7 +435,12 @@ class KDiffusionSampler: else: extra_params_kwargs['sigmas'] = sigmas - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples -- cgit v1.2.3 From 0719c10bf1b817364a498ee11b90d30d3d527344 Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Wed, 19 Oct 2022 13:56:26 -0700 Subject: Fixed copying mistake --- modules/sd_hijack_inpainting.py | 79 +++++++++++++---------------------------- 1 file changed, 25 insertions(+), 54 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 7e5670d6..d4d28d2e 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -19,63 +19,35 @@ from ldm.models.diffusion.ddim import DDIMSampler, noise_like # https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py # ================================================================================================= @torch.no_grad() -def sample( - self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): +def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): if conditioning is not None: if isinstance(conditioning, dict): ctmp = conditioning[list(conditioning.keys())[0]] while isinstance(ctmp, list): - ctmp = elf.inpainting_fill == 2: - self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask - elif self.inpainting_fill == 3: - self.init_latent = self.init_latent * self.mask - - if self.image_mask is not None: - conditioning_mask = np.array(self.image_mask.convert("L")) - conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 - conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) - - # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 - conditioning_mask = torch.round(conditioning_mask) - else: - conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) - - # Create another latent image, this time with a masked version of the original input. - conditioning_mask = conditioning_mask.to(image.device) - conditioning_image = image * (1.0 - conditioning_mask) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) - - # Create the concatenated conditioning tensor to be fed to `c_concat` - conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) - conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) - self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) - self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) - - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - x = create_random_tensors([opctmp[0] + ctmp = ctmp[0] cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") @@ -106,7 +78,6 @@ def sample( ) return samples, intermediates - @torch.no_grad() def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, -- cgit v1.2.3 From dde9f960727bfe151d418e43685a2881cf580a17 Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Wed, 19 Oct 2022 14:14:24 -0700 Subject: added support for ddim img2img --- modules/sd_samplers.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 9d3cf289..d270e4df 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -208,6 +208,12 @@ class VanillaStableDiffusionSampler: self.init_latent = x self.step = 0 + # Wrap the conditioning models with additional image conditioning for inpainting model + if image_conditioning is not None: + conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + + samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) return samples -- cgit v1.2.3 From c418467c03db916c3e5312e6ac4a67365e196dbd Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Wed, 19 Oct 2022 15:09:43 -0700 Subject: Don't compute latent mask if were not using it. Also added support for fixed highres_fix generation. --- modules/processing.py | 72 +++++++++++++++++++++++++++++++------------------- modules/sd_samplers.py | 4 +++ 2 files changed, 49 insertions(+), 27 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index a6c308f9..684e5833 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -541,12 +541,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) - - if not self.enable_hr: - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - + def create_dummy_mask(self, x): + if self.sampler.conditioning_key in {'hybrid', 'concat'}: # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device) image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) @@ -555,11 +551,23 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) image_conditioning = image_conditioning.to(x.dtype) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=image_conditioning) + else: + # Dummy zero conditioning if we're not using inpainting model. + # Still takes up a bit of memory, but no encoder call. + image_conditioning = torch.zeros(x.shape[0], 5, x.shape[-2], x.shape[-1], dtype=x.dtype, device=x.device) + + return image_conditioning + + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) + + if not self.enable_hr: + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x)) return samples x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x)) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] @@ -596,7 +604,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples)) return samples @@ -723,26 +731,36 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - if self.image_mask is not None: - conditioning_mask = np.array(self.image_mask.convert("L")) - conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 - conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + conditioning_key = self.sampler.conditioning_key - # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 - conditioning_mask = torch.round(conditioning_mask) + if conditioning_key in {'hybrid', 'concat'}: + if self.image_mask is not None: + conditioning_mask = np.array(self.image_mask.convert("L")) + conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 + conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: + conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) + + # Create another latent image, this time with a masked version of the original input. + conditioning_mask = conditioning_mask.to(image.device) + conditioning_image = image * (1.0 - conditioning_mask) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + + # Create the concatenated conditioning tensor to be fed to `c_concat` + conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) + conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) + self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) + self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) else: - conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) - - # Create another latent image, this time with a masked version of the original input. - conditioning_mask = conditioning_mask.to(image.device) - conditioning_image = image * (1.0 - conditioning_mask) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) - - # Create the concatenated conditioning tensor to be fed to `c_concat` - conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) - conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) - self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) - self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) + self.image_conditioning = torch.zeros( + self.init_latent.shape[0], 5, self.init_latent.shape[-2], self.init_latent.shape[-1], + dtype=self.init_latent.dtype, + device=self.init_latent.device + ) + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index d270e4df..c21be26e 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler: self.config = None self.last_latent = None + self.conditioning_key = sd_model.model.conditioning_key + def number_of_needed_noises(self, p): return 0 @@ -328,6 +330,8 @@ class KDiffusionSampler: self.config = None self.last_latent = None + self.conditioning_key = sd_model.model.conditioning_key + def callback_state(self, d): step = d['i'] latent = d["denoised"] -- cgit v1.2.3 From d6ea5841374a28f3f6deb73abc251c8f0bcb240f Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:07:57 +0100 Subject: change html output --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d519cd9..73c1cb80 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -380,7 +380,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log Loss: {mean_loss:.7f}
Step: {hypernetwork.step}
Last prompt: {html.escape(entries[0].cond_text)}
-Last saved embedding: {html.escape(last_saved_file)}
+Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" -- cgit v1.2.3 From 166be3919b817cee5e702fd01c34afe9081b952c Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:09:40 +0100 Subject: allow overwrite old hn --- modules/hypernetworks/ui.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 08f75f15..f45345ea 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -10,9 +10,10 @@ from modules import sd_hijack, shared, devices from modules.hypernetworks import hypernetwork -def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False): +def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") - assert not os.path.exists(fn), f"file {fn} already exists" + if not overwrite_old: + assert not os.path.exists(fn), f"file {fn} already exists" if type(layer_structure) == str: layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure))) -- cgit v1.2.3 From 0087079c2d487b67b06ffc30f36ce486a74e6318 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:10:59 +0100 Subject: allow overwrite old embedding --- modules/textual_inversion/textual_inversion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 3be69562..5776778b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -153,7 +153,7 @@ class EmbeddingDatabase: return None, None -def create_embedding(name, num_vectors_per_token, init_text='*'): +def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): cond_model = shared.sd_model.cond_stage_model embedding_layer = cond_model.wrapped.transformer.text_model.embeddings @@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") - assert not os.path.exists(fn), f"file {fn} already exists" + if not overwrite_old: + assert not os.path.exists(fn), f"file {fn} already exists" embedding = Embedding(vec, name) embedding.step = 0 -- cgit v1.2.3 From 632e8d660293081cadb145d8062e5aff0a4a8f0d Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:19:40 +0100 Subject: split learn rates --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index cdb9d335..d07184ee 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1342,7 +1342,7 @@ def create_ui(wrap_gradio_gpu_call): _js="start_training_textual_inversion", inputs=[ train_embedding_name, - learn_rate, + embedding_learn_rate, batch_size, dataset_directory, log_directory, @@ -1367,7 +1367,7 @@ def create_ui(wrap_gradio_gpu_call): _js="start_training_textual_inversion", inputs=[ train_hypernetwork_name, - learn_rate, + hypernetwork_learn_rate, batch_size, dataset_directory, log_directory, -- cgit v1.2.3 From c3835ec85cbb44fa3c46fa871c622b6fee235c89 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:24:24 +0100 Subject: pass overwrite old flag --- modules/textual_inversion/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index 36881e7a..e712284d 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess from modules import sd_hijack, shared -def create_embedding(name, initialization_text, nvpt): - filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text) +def create_embedding(name, initialization_text, nvpt, overwrite_old): + filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() -- cgit v1.2.3 From 4d6b9f76a55fd0ac0f72634071032dd9c6efb409 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:27:16 +0100 Subject: reorder create_hypernetwork params --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index d07184ee..322c082b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1307,9 +1307,9 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ new_hypernetwork_name, new_hypernetwork_sizes, + overwrite_old_hypernetwork, new_hypernetwork_layer_structure, new_hypernetwork_add_layer_norm, - overwrite_old_hypernetwork, ], outputs=[ train_hypernetwork_name, -- cgit v1.2.3 From fbcce66601994f6ed370db36d9c238840fed6bd2 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:46:54 +0100 Subject: add existing caption file handling --- modules/textual_inversion/preprocess.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 886cf0c3..5c43fe13 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -48,7 +48,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro shared.state.textinfo = "Preprocessing..." shared.state.job_count = len(files) - def save_pic_with_caption(image, index): + def save_pic_with_caption(image, index, existing_caption=None): caption = "" if process_caption: @@ -66,17 +66,26 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro basename = f"{index:05}-{subindex[0]}-{filename_part}" image.save(os.path.join(dst, f"{basename}.png")) + if preprocess_txt_action == 'prepend' and existing_caption: + caption = existing_caption + ' ' + caption + elif preprocess_txt_action == 'append' and existing_caption: + caption = caption + ' ' + existing_caption + elif preprocess_txt_action == 'copy' and existing_caption: + caption = existing_caption + + caption = caption.strip() + if len(caption) > 0: with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file: file.write(caption) subindex[0] += 1 - def save_pic(image, index): + def save_pic(image, index, existing_caption=None): save_pic_with_caption(image, index) if process_flip: - save_pic_with_caption(ImageOps.mirror(image), index) + save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption) for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] @@ -86,6 +95,13 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro except Exception: continue + existing_caption = None + + try: + existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read() + except Exception as e: + print(e) + if shared.state.interrupted: break @@ -97,20 +113,20 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro img = img.resize((width, height * img.height // img.width)) top = img.crop((0, 0, width, height)) - save_pic(top, index) + save_pic(top, index, existing_caption=existing_caption) bot = img.crop((0, img.height - height, width, img.height)) - save_pic(bot, index) + save_pic(bot, index, existing_caption=existing_caption) elif process_split and is_wide: img = img.resize((width * img.width // img.height, height)) left = img.crop((0, 0, width, height)) - save_pic(left, index) + save_pic(left, index, existing_caption=existing_caption) right = img.crop((img.width - width, 0, img.width, height)) - save_pic(right, index) + save_pic(right, index, existing_caption=existing_caption) else: img = images.resize_image(1, img, width, height) - save_pic(img, index) + save_pic(img, index, existing_caption=existing_caption) shared.state.nextjob() -- cgit v1.2.3 From ab353b141df8eee042b0964bcb645015dabf3459 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:48:07 +0100 Subject: link existing txt option --- modules/ui.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 322c082b..7f52ac0c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1234,6 +1234,7 @@ def create_ui(wrap_gradio_gpu_call): process_dst = gr.Textbox(label='Destination directory') process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', choices=['ignore', 'copy', 'prepend', 'append']) with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') @@ -1326,6 +1327,7 @@ def create_ui(wrap_gradio_gpu_call): process_dst, process_width, process_height, + preprocess_txt_action, process_flip, process_split, process_caption, -- cgit v1.2.3 From 9b65c4ecf4f8eb6187ee721918adebe68e9bc631 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:49:23 +0100 Subject: pass preprocess_txt_action param --- modules/textual_inversion/preprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 5c43fe13..3713bc89 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -11,7 +11,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False): try: if process_caption: shared.interrogator.load() @@ -21,7 +21,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru) finally: @@ -33,7 +33,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ -def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False): width = process_width height = process_height src = os.path.abspath(process_src) -- cgit v1.2.3 From 55d8c6cce6d3aef848b9f194adad2ce53064d8b7 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 00:53:29 +0100 Subject: default to ignore existing captions --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 7f52ac0c..bd5f1b05 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1234,7 +1234,7 @@ def create_ui(wrap_gradio_gpu_call): process_dst = gr.Textbox(label='Destination directory') process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', choices=['ignore', 'copy', 'prepend', 'append']) + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') -- cgit v1.2.3 From 8b74b9aa9a20e4c5c1f72641f8b9617479eb276b Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Wed, 19 Oct 2022 19:06:14 -0500 Subject: add symbol for clear button and simplify roll_col css selector --- modules/ui.py | 2 ++ style.css | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index a2dbd41e..9f6edc5f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -83,6 +83,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 +trash_prompt_symbol = '\U0001F5D1' # 🗑🗑🗑 def plaintext_to_html(text): @@ -498,6 +499,7 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") + trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt") token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") diff --git a/style.css b/style.css index 26ae36a5..21a8911f 100644 --- a/style.css +++ b/style.css @@ -114,7 +114,7 @@ padding: 0.4em 0; } -#roll, #paste, #style_create, #style_apply{ +#roll_col > button { min-width: 2em; min-height: 2em; max-width: 2em; -- cgit v1.2.3 From 6f98e89486f55b0e4657e96ce640cf1c4675d187 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Thu, 20 Oct 2022 00:10:45 +0000 Subject: update --- modules/hypernetworks/hypernetwork.py | 29 +++++++++++++++-------- modules/hypernetworks/ui.py | 3 ++- modules/ui.py | 43 +++++++++++++++++++---------------- 3 files changed, 44 insertions(+), 31 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 74300122..7d617680 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -22,16 +22,20 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): multiplier = 1.0 - def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False): + def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None): super().__init__() - assert layer_structure is not None, "layer_structure mut not be None" + assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" linears = [] for i in range(len(layer_structure) - 1): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) + if activation_func == "relu": + linears.append(torch.nn.ReLU()) + if activation_func == "leakyrelu": + linears.append(torch.nn.LeakyReLU()) if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -42,8 +46,9 @@ class HypernetworkModule(torch.nn.Module): self.load_state_dict(state_dict) else: for layer in self.linear: - layer.weight.data.normal_(mean=0.0, std=0.01) - layer.bias.data.zero_() + if not "ReLU" in layer.__str__(): + layer.weight.data.normal_(mean=0.0, std=0.01) + layer.bias.data.zero_() self.to(devices.device) @@ -69,7 +74,8 @@ class HypernetworkModule(torch.nn.Module): def trainables(self): layer_structure = [] for layer in self.linear: - layer_structure += [layer.weight, layer.bias] + if not "ReLU" in layer.__str__(): + layer_structure += [layer.weight, layer.bias] return layer_structure @@ -81,7 +87,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None): self.filename = None self.name = name self.layers = {} @@ -90,11 +96,12 @@ class Hypernetwork: self.sd_checkpoint_name = None self.layer_structure = layer_structure self.add_layer_norm = add_layer_norm + self.activation_func = activation_func for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), - HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), + HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), ) def weights(self): @@ -117,6 +124,7 @@ class Hypernetwork: state_dict['name'] = self.name state_dict['layer_structure'] = self.layer_structure state_dict['is_layer_norm'] = self.add_layer_norm + state_dict['activation_func'] = self.activation_func state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name @@ -131,12 +139,13 @@ class Hypernetwork: self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) self.add_layer_norm = state_dict.get('is_layer_norm', False) + self.activation_func = state_dict.get('activation_func', None) for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm), - HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func), + HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func), ) self.name = state_dict.get('name', self.name) diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 08f75f15..83f9547b 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -10,7 +10,7 @@ from modules import sd_hijack, shared, devices from modules.hypernetworks import hypernetwork -def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False): +def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" @@ -22,6 +22,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm enable_sizes=[int(x) for x in enable_sizes], layer_structure=layer_structure, add_layer_norm=add_layer_norm, + activation_func=activation_func, ) hypernet.save(fn) diff --git a/modules/ui.py b/modules/ui.py index d2e24880..8751fa9c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -5,43 +5,44 @@ import json import math import mimetypes import os +import platform import random +import subprocess as sp import sys import tempfile import time import traceback -import platform -import subprocess as sp from functools import partial, reduce +import gradio as gr +import gradio.routes +import gradio.utils import numpy as np +import piexif import torch from PIL import Image, PngImagePlugin -import piexif -import gradio as gr -import gradio.utils -import gradio.routes - -from modules import sd_hijack, sd_models, localization +from modules import localization, sd_hijack, sd_models from modules.paths import script_path -from modules.shared import opts, cmd_opts, restricted_opts +from modules.shared import cmd_opts, opts, restricted_opts + if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags -import modules.shared as shared -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.sd_hijack import model_hijack + +import modules.codeformer_model +import modules.generation_parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.images_history as img_his import modules.ldsr_model import modules.scripts -import modules.gfpgan_model -import modules.codeformer_model +import modules.shared as shared import modules.styles -import modules.generation_parameters_copypaste +import modules.textual_inversion.ui from modules import prompt_parser from modules.images import save_image -import modules.textual_inversion.ui -import modules.hypernetworks.ui -import modules.images_history as img_his +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -268,8 +269,8 @@ def calc_time_left(progress, threshold, label, force_display): time_since_start = time.time() - shared.state.time_start eta = (time_since_start/progress) eta_relative = eta-time_since_start - if (eta_relative > threshold and progress > 0.02) or force_display: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + if (eta_relative > threshold and progress > 0.02) or force_display: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) else: return "" @@ -1219,6 +1220,7 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") + new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["relu", "leakyrelu"]) with gr.Row(): with gr.Column(scale=3): @@ -1303,6 +1305,7 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_sizes, new_hypernetwork_layer_structure, new_hypernetwork_add_layer_norm, + new_hypernetwork_activation_func, ], outputs=[ train_hypernetwork_name, -- cgit v1.2.3 From ba469343e6a1c6e23e82acf5feb65c6101dacbb2 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Thu, 20 Oct 2022 00:17:04 +0000 Subject: align ui.py imports with upstream --- modules/ui.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 987b1d7d..913b23b4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -5,44 +5,43 @@ import json import math import mimetypes import os -import platform import random -import subprocess as sp import sys import tempfile import time import traceback +import platform +import subprocess as sp from functools import partial, reduce -import gradio as gr -import gradio.routes -import gradio.utils import numpy as np -import piexif import torch from PIL import Image, PngImagePlugin +import piexif -from modules import localization, sd_hijack, sd_models -from modules.paths import script_path -from modules.shared import cmd_opts, opts, restricted_opts +import gradio as gr +import gradio.utils +import gradio.routes +from modules import sd_hijack, sd_models, localization +from modules.paths import script_path +from modules.shared import opts, cmd_opts, restricted_opts if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags - -import modules.codeformer_model -import modules.generation_parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.images_history as img_his +import modules.shared as shared +from modules.sd_samplers import samplers, samplers_for_img2img +from modules.sd_hijack import model_hijack import modules.ldsr_model import modules.scripts -import modules.shared as shared +import modules.gfpgan_model +import modules.codeformer_model import modules.styles -import modules.textual_inversion.ui +import modules.generation_parameters_copypaste from modules import prompt_parser from modules.images import save_image -from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img +import modules.textual_inversion.ui +import modules.hypernetworks.ui +import modules.images_history as img_his # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() -- cgit v1.2.3 From 59ed74438318af893d2cba552b0e28dbc2a9266c Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 17:19:02 -0700 Subject: face detection algo, configurability, reusability Try to move the crop in the direction of a face if it is present More internal configuration options for choosing weights of each of the algorithm's findings Move logic into its module --- modules/textual_inversion/autocrop.py | 216 ++++++++++++++++++++++++++++++++ modules/textual_inversion/preprocess.py | 150 +++------------------- 2 files changed, 230 insertions(+), 136 deletions(-) create mode 100644 modules/textual_inversion/autocrop.py (limited to 'modules') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py new file mode 100644 index 00000000..f858a958 --- /dev/null +++ b/modules/textual_inversion/autocrop.py @@ -0,0 +1,216 @@ +import cv2 +from collections import defaultdict +from math import log, sqrt +import numpy as np +from PIL import Image, ImageDraw + +GREEN = "#0F0" +BLUE = "#00F" +RED = "#F00" + +def crop_image(im, settings): + """ Intelligently crop an image to the subject matter """ + if im.height > im.width: + im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width)) + else: + im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height)) + + focus = focal_point(im, settings) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(settings.crop_height / 2) + x_half = int(settings.crop_width / 2) + + x1 = focus.x - x_half + if x1 < 0: + x1 = 0 + elif x1 + settings.crop_width > im.width: + x1 = im.width - settings.crop_width + + y1 = focus.y - y_half + if y1 < 0: + y1 = 0 + elif y1 + settings.crop_height > im.height: + y1 = im.height - settings.crop_height + + x2 = x1 + settings.crop_width + y2 = y1 + settings.crop_height + + crop = [x1, y1, x2, y2] + + if settings.annotate_image: + d = ImageDraw.Draw(im) + rect = list(crop) + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + if settings.destop_view_image: + im.show() + + return im.crop(tuple(crop)) + +def focal_point(im, settings): + corner_points = image_corner_points(im, settings) + entropy_points = image_entropy_points(im, settings) + face_points = image_face_points(im, settings) + + total_points = len(corner_points) + len(entropy_points) + len(face_points) + + corner_weight = settings.corner_points_weight + entropy_weight = settings.entropy_points_weight + face_weight = settings.face_points_weight + + weight_pref_total = corner_weight + entropy_weight + face_weight + + # weight things + pois = [] + if weight_pref_total == 0 or total_points == 0: + return pois + + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] + ) + + if settings.annotate_image: + d = ImageDraw.Draw(im) + + average_point = poi_average(pois, settings, im=im) + + if settings.annotate_image: + d.ellipse([average_point.x - 25, average_point.y - 25, average_point.x + 25, average_point.y + 25], outline=GREEN) + + return average_point + + +def image_face_points(im, settings): + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + classifier = cv2.CascadeClassifier(f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml') + + minsize = int(min(im.width, im.height) * 0.15) # at least N percent of the smallest side + faces = classifier.detectMultiScale(gray, scaleFactor=1.05, + minNeighbors=5, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + + if len(faces) == 0: + return [] + + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + if settings.annotate_image: + for f in rects: + d = ImageDraw.Draw(im) + d.rectangle(f, outline=RED) + + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2) for r in rects] + + +def image_corner_points(im, settings): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=100, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.07, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y)) + + return focal_points + + +def image_entropy_points(im, settings): + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] + else: + return [] + + e_max = 0 + crop_current = [0, 0, settings.crop_width, settings.crop_height] + crop_best = crop_current + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e > e_max): + e_max = e + crop_best = list(crop_current) + + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + settings.crop_width/2) + y_mid = int(crop_best[1] + settings.crop_height/2) + + return [PointOfInterest(x_mid, y_mid)] + + +def image_entropy(im): + # greyscale image entropy + band = np.asarray(im.convert("1")) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + + +def poi_average(pois, settings, im=None): + weight = 0.0 + x = 0.0 + y = 0.0 + for pois in pois: + if settings.annotate_image and im is not None: + w = 4 * 0.5 * sqrt(pois.weight) + d = ImageDraw.Draw(im) + d.ellipse([ + pois.x - w, pois.y - w, + pois.x + w, pois.y + w ], fill=BLUE) + weight += pois.weight + x += pois.x * pois.weight + y += pois.y * pois.weight + avg_x = round(x / weight) + avg_y = round(y / weight) + + return PointOfInterest(avg_x, avg_y) + + +class PointOfInterest: + def __init__(self, x, y, weight=1.0): + self.x = x + self.y = y + self.weight = weight + + +class Settings: + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): + self.crop_width = crop_width + self.crop_height = crop_height + self.corner_points_weight = corner_points_weight + self.entropy_points_weight = entropy_points_weight + self.face_points_weight = entropy_points_weight + self.annotate_image = annotate_image + self.destop_view_image = False \ No newline at end of file diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 7c1a594e..0c79f012 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,7 +1,5 @@ import os -import cv2 -import numpy as np -from PIL import Image, ImageOps, ImageDraw +from PIL import Image, ImageOps import platform import sys import tqdm @@ -9,6 +7,7 @@ import time from modules import shared, images from modules.shared import opts, cmd_opts +from modules.textual_inversion import autocrop if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru @@ -80,6 +79,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro if process_flip: save_pic_with_caption(ImageOps.mirror(image), index) + for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] filename = os.path.join(src, imagefile) @@ -118,37 +118,16 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro processing_option_ran = True - if process_entropy_focus and (is_tall or is_wide): - if is_tall: - img = img.resize((width, height * img.height // img.width)) - else: - img = img.resize((width * img.width // img.height, height)) - - x_focal_center, y_focal_center = image_central_focal_point(img, width, height) - - # take the focal point and turn it into crop coordinates that try to center over the focal - # point but then get adjusted back into the frame - y_half = int(height / 2) - x_half = int(width / 2) - - x1 = x_focal_center - x_half - if x1 < 0: - x1 = 0 - elif x1 + width > img.width: - x1 = img.width - width - - y1 = y_focal_center - y_half - if y1 < 0: - y1 = 0 - elif y1 + height > img.height: - y1 = img.height - height - - x2 = x1 + width - y2 = y1 + height - - crop = [x1, y1, x2, y2] - - focal = img.crop(tuple(crop)) + if process_entropy_focus and img.height != img.width: + autocrop_settings = autocrop.Settings( + crop_width = width, + crop_height = height, + face_points_weight = 0.9, + entropy_points_weight = 0.7, + corner_points_weight = 0.5, + annotate_image = False + ) + focal = autocrop.crop_image(img, autocrop_settings) save_pic(focal, index) processing_option_ran = True @@ -157,105 +136,4 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro img = images.resize_image(1, img, width, height) save_pic(img, index) - shared.state.nextjob() - - -def image_central_focal_point(im, target_width, target_height): - focal_points = [] - - focal_points.extend( - image_focal_points(im) - ) - - fp_entropy = image_entropy_point(im, target_width, target_height) - fp_entropy['weight'] = len(focal_points) + 1 # about half of the weight to entropy - - focal_points.append(fp_entropy) - - weight = 0.0 - x = 0.0 - y = 0.0 - for focal_point in focal_points: - weight += focal_point['weight'] - x += focal_point['x'] * focal_point['weight'] - y += focal_point['y'] * focal_point['weight'] - avg_x = round(x // weight) - avg_y = round(y // weight) - - return avg_x, avg_y - - -def image_focal_points(im): - grayscale = im.convert("L") - - # naive attempt at preventing focal points from collecting at watermarks near the bottom - gd = ImageDraw.Draw(grayscale) - gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") - - np_im = np.array(grayscale) - - points = cv2.goodFeaturesToTrack( - np_im, - maxCorners=100, - qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.07, - useHarrisDetector=False, - ) - - if points is None: - return [] - - focal_points = [] - for point in points: - x, y = point.ravel() - focal_points.append({ - 'x': x, - 'y': y, - 'weight': 1.0 - }) - - return focal_points - - -def image_entropy_point(im, crop_width, crop_height): - landscape = im.height < im.width - portrait = im.height > im.width - if landscape: - move_idx = [0, 2] - move_max = im.size[0] - elif portrait: - move_idx = [1, 3] - move_max = im.size[1] - - e_max = 0 - crop_current = [0, 0, crop_width, crop_height] - crop_best = crop_current - while crop_current[move_idx[1]] < move_max: - crop = im.crop(tuple(crop_current)) - e = image_entropy(crop) - - if (e > e_max): - e_max = e - crop_best = list(crop_current) - - crop_current[move_idx[0]] += 4 - crop_current[move_idx[1]] += 4 - - x_mid = int(crop_best[0] + crop_width/2) - y_mid = int(crop_best[1] + crop_height/2) - - - return { - 'x': x_mid, - 'y': y_mid, - 'weight': 1.0 - } - - -def image_entropy(im): - # greyscale image entropy - band = np.asarray(im.convert("1")) - hist, _ = np.histogram(band, bins=range(0, 256)) - hist = hist[hist > 0] - return -np.log2(hist / hist.sum()).sum() - + shared.state.nextjob() \ No newline at end of file -- cgit v1.2.3 From 858462f719c22ca9f24b94a41699653c34b5f4fb Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 20 Oct 2022 02:57:18 +0100 Subject: do caption copy for both flips --- modules/textual_inversion/preprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 3713bc89..6bba3852 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -82,7 +82,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre subindex[0] += 1 def save_pic(image, index, existing_caption=None): - save_pic_with_caption(image, index) + save_pic_with_caption(image, index, existing_caption=existing_caption) if process_flip: save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption) -- cgit v1.2.3 From c6345bd445463b7aa41723d6637e80dfa293a890 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Wed, 19 Oct 2022 21:23:57 -0500 Subject: nerf line length --- modules/ui.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 9f6edc5f..cb9a6c6e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -83,7 +83,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 -trash_prompt_symbol = '\U0001F5D1' # 🗑🗑🗑 +trash_prompt_symbol = '\U0001F5D1' # def plaintext_to_html(text): @@ -617,7 +617,10 @@ def create_ui(wrap_gradio_gpu_call): return refresh_button with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\ + txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\ + token_button = create_toprow(is_img2img=False) + dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) -- cgit v1.2.3 From aa7ff2a1972f3865883e10ba28c5414cdebe8e3b Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Wed, 19 Oct 2022 21:46:13 -0700 Subject: Fixed non-square highres fix generation --- modules/processing.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 684e5833..3caac25e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -541,10 +541,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def create_dummy_mask(self, x): + def create_dummy_mask(self, x, first_phase: bool = False): if self.sampler.conditioning_key in {'hybrid', 'concat'}: + height = self.firstphase_height if first_phase else self.height + width = self.firstphase_width if first_phase else self.width + # The "masked-image" in this case will just be all zeros since the entire image is masked. - image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device) + image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) # Add the fake full 1s mask to the first dimension. @@ -567,7 +570,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return samples x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x)) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, first_phase=True)) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] -- cgit v1.2.3 From 158d678f596d7fc304a6ce2f0dc31f8abfe62250 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Thu, 20 Oct 2022 01:08:24 -0500 Subject: clear prompt button now works on both relevant tabs. Device detection stuff will be added later. --- javascript/ui.js | 28 ++++++++++++++++++++++++++++ modules/ui.py | 21 ++++++++++++++++++--- 2 files changed, 46 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/javascript/ui.js b/javascript/ui.js index cfd0dcd3..165383da 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -151,6 +151,34 @@ function ask_for_style_name(_, prompt_text, negative_prompt_text) { return [name_, prompt_text, negative_prompt_text] } +// returns css id for currently selected tab in ui +function selected_tab_id() { + tabs = gradioApp().querySelectorAll('#tabs div.tabitem') + + for(var tab = 0; tab < tabs.length; tab++) { + if (tabs[tab].style.display != "none") return tabs[tab].id + + } + +} + +function trash_prompt(_,_, is_img2img) { + + if(selected_tab_id() == "tab_txt2img") { + pos_prompt = txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea"); + neg_prompt = txt2img_textarea = gradioApp().querySelector("#txt2img_neg_prompt > label > textarea"); + + pos_prompt.value = "" + neg_prompt.value = "" + } else { + pos_prompt = txt2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); + neg_prompt = txt2img_textarea = gradioApp().querySelector("#img2img_neg_prompt > label > textarea"); + + pos_prompt.value = "" + neg_prompt.value = "" + } +} + opts = {} diff --git a/modules/ui.py b/modules/ui.py index cb9a6c6e..bde546cc 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -424,6 +424,16 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox +# setup button for clearing prompt input boxes on client side of webui +def connect_trash_prompt(dummy_component, button, is_img2img): + + button.click( + fn=lambda: print("Clearing prompt"), + _js="trash_prompt", + inputs=[], + outputs=[], + ) + def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): """ Connects a 'reuse (sub)seed' button's click event so that it copies last used (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength @@ -540,7 +550,7 @@ def create_toprow(is_img2img): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) prompt_style2.save_to_config = True - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, trash_prompt def setup_progressbar(progressbar, preview, id_part, textinfo=None): @@ -619,10 +629,11 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\ txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\ - token_button = create_toprow(is_img2img=False) + token_button, trash_prompt_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + connect_trash_prompt(dummy_component, trash_prompt_button, False) with gr.Row(elem_id='txt2img_progress_row'): with gr.Column(scale=1): @@ -807,7 +818,11 @@ def create_ui(wrap_gradio_gpu_call): token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True) + img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit,\ + img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ + token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) + + connect_trash_prompt(dummy_component,trash_prompt_button, True) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) -- cgit v1.2.3 From 0ddaf8d2028a7251e8c4ad93551a43b5d4700841 Mon Sep 17 00:00:00 2001 From: captin411 Date: Thu, 20 Oct 2022 00:34:55 -0700 Subject: improve face detection a lot --- modules/textual_inversion/autocrop.py | 99 ++++++++++++++++++++++------------- 1 file changed, 62 insertions(+), 37 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index f858a958..5a551c25 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -8,12 +8,18 @@ GREEN = "#0F0" BLUE = "#00F" RED = "#F00" + def crop_image(im, settings): """ Intelligently crop an image to the subject matter """ if im.height > im.width: im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width)) - else: + elif im.width > im.height: im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height)) + else: + im = im.resize((settings.crop_width, settings.crop_height)) + + if im.height == im.width: + return im focus = focal_point(im, settings) @@ -78,13 +84,18 @@ def focal_point(im, settings): [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] ) - if settings.annotate_image: - d = ImageDraw.Draw(im) - - average_point = poi_average(pois, settings, im=im) + average_point = poi_average(pois, settings) if settings.annotate_image: - d.ellipse([average_point.x - 25, average_point.y - 25, average_point.x + 25, average_point.y + 25], outline=GREEN) + d = ImageDraw.Draw(im) + for f in face_points: + d.rectangle(f.bounding(f.size), outline=RED) + for f in entropy_points: + d.rectangle(f.bounding(30), outline=BLUE) + for poi in pois: + w = max(4, 4 * 0.5 * sqrt(poi.weight)) + d.ellipse(poi.bounding(w), fill=BLUE) + d.ellipse(average_point.bounding(25), outline=GREEN) return average_point @@ -92,22 +103,32 @@ def focal_point(im, settings): def image_face_points(im, settings): np_im = np.array(im) gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - classifier = cv2.CascadeClassifier(f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml') - - minsize = int(min(im.width, im.height) * 0.15) # at least N percent of the smallest side - faces = classifier.detectMultiScale(gray, scaleFactor=1.05, - minNeighbors=5, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - if len(faces) == 0: - return [] - - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - if settings.annotate_image: - for f in rects: - d = ImageDraw.Draw(im) - d.rectangle(f, outline=RED) - - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2) for r in rects] + tries = [ + [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] + ] + + for t in tries: + # print(t[0]) + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + except: + continue + + if len(faces) > 0: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] + return [] def image_corner_points(im, settings): @@ -132,8 +153,8 @@ def image_corner_points(im, settings): focal_points = [] for point in points: - x, y = point.ravel() - focal_points.append(PointOfInterest(x, y)) + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4)) return focal_points @@ -167,31 +188,26 @@ def image_entropy_points(im, settings): x_mid = int(crop_best[0] + settings.crop_width/2) y_mid = int(crop_best[1] + settings.crop_height/2) - return [PointOfInterest(x_mid, y_mid)] + return [PointOfInterest(x_mid, y_mid, size=25)] def image_entropy(im): # greyscale image entropy - band = np.asarray(im.convert("1")) + # band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1"), dtype=np.uint8) hist, _ = np.histogram(band, bins=range(0, 256)) hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() -def poi_average(pois, settings, im=None): +def poi_average(pois, settings): weight = 0.0 x = 0.0 y = 0.0 - for pois in pois: - if settings.annotate_image and im is not None: - w = 4 * 0.5 * sqrt(pois.weight) - d = ImageDraw.Draw(im) - d.ellipse([ - pois.x - w, pois.y - w, - pois.x + w, pois.y + w ], fill=BLUE) - weight += pois.weight - x += pois.x * pois.weight - y += pois.y * pois.weight + for poi in pois: + weight += poi.weight + x += poi.x * poi.weight + y += poi.y * poi.weight avg_x = round(x / weight) avg_y = round(y / weight) @@ -199,10 +215,19 @@ def poi_average(pois, settings, im=None): class PointOfInterest: - def __init__(self, x, y, weight=1.0): + def __init__(self, x, y, weight=1.0, size=10): self.x = x self.y = y self.weight = weight + self.size = size + + def bounding(self, size): + return [ + self.x - size//2, + self.y - size//2, + self.x + size//2, + self.y + size//2 + ] class Settings: -- cgit v1.2.3 From f8733ad08be08bafb40f4299785590e11f049e96 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Thu, 20 Oct 2022 11:07:37 +0000 Subject: add linear as a act func (option for doin nothing) --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 913b23b4..716f14b8 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1224,7 +1224,7 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") - new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["relu", "leakyrelu"]) + new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"]) with gr.Row(): with gr.Column(scale=3): -- cgit v1.2.3 From 9681419e422515e42444e0174355b760645a846f Mon Sep 17 00:00:00 2001 From: Milly Date: Thu, 20 Oct 2022 16:53:46 +0900 Subject: train: fixed preprocess image ratio --- modules/textual_inversion/preprocess.py | 54 +++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 19 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 886cf0c3..2743bdeb 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,5 +1,6 @@ import os from PIL import Image, ImageOps +import math import platform import sys import tqdm @@ -38,6 +39,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) + split_threshold = 0.5 + overlap_ratio = 0.2 assert src != dst, 'same directory specified as source and destination' @@ -78,6 +81,29 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro if process_flip: save_pic_with_caption(ImageOps.mirror(image), index) + def split_pic(image, inverse_xy): + if inverse_xy: + from_w, from_h = image.height, image.width + to_w, to_h = height, width + else: + from_w, from_h = image.width, image.height + to_w, to_h = width, height + h = from_h * to_w // from_w + if inverse_xy: + image = image.resize((h, to_w)) + else: + image = image.resize((to_w, h)) + + split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) + y_step = (h - to_h) / (split_count - 1) + for i in range(split_count): + y = int(y_step * i) + if inverse_xy: + splitted = image.crop((y, 0, y + to_h, to_w)) + else: + splitted = image.crop((0, y, to_w, y + to_h)) + yield splitted + for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] filename = os.path.join(src, imagefile) @@ -89,26 +115,16 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro if shared.state.interrupted: break - ratio = img.height / img.width - is_tall = ratio > 1.35 - is_wide = ratio < 1 / 1.35 - - if process_split and is_tall: - img = img.resize((width, height * img.height // img.width)) - - top = img.crop((0, 0, width, height)) - save_pic(top, index) - - bot = img.crop((0, img.height - height, width, img.height)) - save_pic(bot, index) - elif process_split and is_wide: - img = img.resize((width * img.width // img.height, height)) - - left = img.crop((0, 0, width, height)) - save_pic(left, index) + if img.height > img.width: + ratio = (img.width * height) / (img.height * width) + inverse_xy = False + else: + ratio = (img.height * width) / (img.width * height) + inverse_xy = True - right = img.crop((img.width - width, 0, img.width, height)) - save_pic(right, index) + if process_split and ratio < 1.0 and ratio <= split_threshold: + for splitted in split_pic(img, inverse_xy): + save_pic(splitted, index) else: img = images.resize_image(1, img, width, height) save_pic(img, index) -- cgit v1.2.3 From 85dd62c4c7635b8e21a75f140d093036069e97a1 Mon Sep 17 00:00:00 2001 From: Milly Date: Thu, 20 Oct 2022 22:56:45 +0900 Subject: train: ui: added `Split image threshold` and `Split image overlap ratio` to preprocess --- modules/textual_inversion/preprocess.py | 10 +++++----- modules/ui.py | 16 ++++++++++++++-- 2 files changed, 19 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 2743bdeb..c8df8aa0 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,7 +12,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2): try: if process_caption: shared.interrogator.load() @@ -22,7 +22,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio) finally: @@ -34,13 +34,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ -def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2): width = process_width height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) - split_threshold = 0.5 - overlap_ratio = 0.2 + split_threshold = max(0.0, min(1.0, split_threshold)) + overlap_ratio = max(0.0, min(0.9, overlap_ratio)) assert src != dst, 'same directory specified as source and destination' diff --git a/modules/ui.py b/modules/ui.py index a2dbd41e..bc7f3330 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1240,10 +1240,14 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') - process_split = gr.Checkbox(label='Split oversized images into two') + process_split = gr.Checkbox(label='Split oversized images') process_caption = gr.Checkbox(label='Use BLIP for caption') process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) + with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") @@ -1251,6 +1255,12 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): run_preprocess = gr.Button(value="Preprocess", variant='primary') + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + with gr.Tab(label="Train"): gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") with gr.Row(): @@ -1327,7 +1337,9 @@ def create_ui(wrap_gradio_gpu_call): process_flip, process_split, process_caption, - process_caption_deepbooru + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, ], outputs=[ ti_output, -- cgit v1.2.3 From d8acd34f66ab35a91f10d66330bcc95a83bfcac6 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Thu, 20 Oct 2022 23:43:03 +0900 Subject: generalized some functions and option for ignoring first layer --- modules/hypernetworks/hypernetwork.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d617680..3a44b377 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -21,21 +21,27 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): multiplier = 1.0 - + activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU, + "swish": torch.nn.Hardswish} + def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None): super().__init__() assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - + linears = [] for i in range(len(layer_structure) - 1): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) - if activation_func == "relu": - linears.append(torch.nn.ReLU()) - if activation_func == "leakyrelu": - linears.append(torch.nn.LeakyReLU()) + # if skip_first_layer because first parameters potentially contain negative values + if i < 1: continue + if activation_func in HypernetworkModule.activation_dict: + linears.append(HypernetworkModule.activation_dict[activation_func]()) + else: + print("Invalid key {} encountered as activation function!".format(activation_func)) + # if use_dropout: + linears.append(torch.nn.Dropout(p=0.3)) if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module): self.load_state_dict(state_dict) else: for layer in self.linear: - if not "ReLU" in layer.__str__(): + if isinstance(layer, torch.nn.Linear): layer.weight.data.normal_(mean=0.0, std=0.01) layer.bias.data.zero_() @@ -298,7 +304,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) + # if optimizer == "Adam": or else Adam / AdamW / etc... + optimizer = torch.optim.Adam(weights, lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: -- cgit v1.2.3 From a71e0212363979c7cbbb797c9fbd5f8cd03b29d3 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Thu, 20 Oct 2022 23:48:52 +0900 Subject: only linear --- modules/hypernetworks/hypernetwork.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3a44b377..905cbeef 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -35,13 +35,13 @@ class HypernetworkModule(torch.nn.Module): for i in range(len(layer_structure) - 1): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) # if skip_first_layer because first parameters potentially contain negative values - if i < 1: continue + # if i < 1: continue if activation_func in HypernetworkModule.activation_dict: linears.append(HypernetworkModule.activation_dict[activation_func]()) else: print("Invalid key {} encountered as activation function!".format(activation_func)) # if use_dropout: - linears.append(torch.nn.Dropout(p=0.3)) + # linears.append(torch.nn.Dropout(p=0.3)) if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module): def trainables(self): layer_structure = [] for layer in self.linear: - if not "ReLU" in layer.__str__(): + if isinstance(layer, torch.nn.Linear): layer_structure += [layer.weight, layer.bias] return layer_structure @@ -304,8 +304,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - # if optimizer == "Adam": or else Adam / AdamW / etc... - optimizer = torch.optim.Adam(weights, lr=scheduler.learn_rate) + # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... + optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: -- cgit v1.2.3 From d07cb46f34b3d9fe7a78b102f899ebef352ea56b Mon Sep 17 00:00:00 2001 From: yfszzx Date: Thu, 20 Oct 2022 23:58:52 +0800 Subject: inspiration pull request --- .gitignore | 3 +- javascript/imageviewer.js | 1 - javascript/inspiration.js | 42 ++++++++++++ modules/inspiration.py | 122 +++++++++++++++++++++++++++++++++++ modules/shared.py | 1 + modules/ui.py | 13 ++-- scripts/create_inspiration_images.py | 45 +++++++++++++ webui.py | 5 ++ 8 files changed, 225 insertions(+), 7 deletions(-) create mode 100644 javascript/inspiration.js create mode 100644 modules/inspiration.py create mode 100644 scripts/create_inspiration_images.py (limited to 'modules') diff --git a/.gitignore b/.gitignore index f9c3357c..434d50b7 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,5 @@ __pycache__ notification.mp3 /SwinIR /textual_inversion -.vscode \ No newline at end of file +.vscode +/inspiration \ No newline at end of file diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 9e380c65..d4ab6984 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -116,7 +116,6 @@ function showGalleryImage() { e.dataset.modded = true; if(e && e.parentElement.tagName == 'DIV'){ e.style.cursor='pointer' - e.style.userSelect='none' e.addEventListener('click', function (evt) { if(!opts.js_modal_lightbox) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) diff --git a/javascript/inspiration.js b/javascript/inspiration.js new file mode 100644 index 00000000..e1c0e114 --- /dev/null +++ b/javascript/inspiration.js @@ -0,0 +1,42 @@ +function public_image_index_in_gallery(item, gallery){ + var index; + var i = 0; + gallery.querySelectorAll("img").forEach(function(e){ + if (e == item) + index = i; + i += 1; + }); + return index; +} + +function inspiration_selected(name, types, name_list){ + var btn = gradioApp().getElementById("inspiration_select_button") + return [gradioApp().getElementById("inspiration_select_button").getAttribute("img-index"), types]; +} +var inspiration_image_click = function(){ + var index = public_image_index_in_gallery(this, gradioApp().getElementById("inspiration_gallery")); + var btn = gradioApp().getElementById("inspiration_select_button") + btn.setAttribute("img-index", index) + setTimeout(function(btn){btn.click();}, 10, btn) +} + +document.addEventListener("DOMContentLoaded", function() { + var mutationObserver = new MutationObserver(function(m){ + var gallery = gradioApp().getElementById("inspiration_gallery") + if (gallery) { + var node = gallery.querySelector(".absolute.backdrop-blur.h-full") + if (node) { + node.style.display = "None"; //parentNode.removeChild(node) + } + + gallery.querySelectorAll('img').forEach(function(e){ + e.onclick = inspiration_image_click + }) + + } + + + }); + mutationObserver.observe( gradioApp(), { childList:true, subtree:true }); + +}); diff --git a/modules/inspiration.py b/modules/inspiration.py new file mode 100644 index 00000000..456bfcb5 --- /dev/null +++ b/modules/inspiration.py @@ -0,0 +1,122 @@ +import os +import random +import gradio +inspiration_path = "inspiration" +inspiration_system_path = os.path.join(inspiration_path, "system") +def read_name_list(file): + if not os.path.exists(file): + return [] + f = open(file, "r") + ret = [] + line = f.readline() + while len(line) > 0: + line = line.rstrip("\n") + ret.append(line) + print(ret) + return ret + +def save_name_list(file, name): + print(file) + f = open(file, "a") + f.write(name + "\n") + +def get_inspiration_images(source, types): + path = os.path.join(inspiration_path , types) + if source == "Favorites": + names = read_name_list(os.path.join(inspiration_system_path, types + "_faverites.txt")) + names = random.sample(names, 25) + elif source == "Abandoned": + names = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) + names = random.sample(names, 25) + elif source == "Exclude abandoned": + abondened = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) + all_names = os.listdir(path) + names = [] + while len(names) < 25: + name = random.choice(all_names) + if name not in abondened: + names.append(name) + else: + names = random.sample(os.listdir(path), 25) + names = random.sample(names, 25) + image_list = [] + for a in names: + image_path = os.path.join(path, a) + images = os.listdir(image_path) + image_list.append(os.path.join(image_path, random.choice(images))) + return image_list, names + +def select_click(index, types, name_list): + name = name_list[int(index)] + path = os.path.join(inspiration_path, types, name) + images = os.listdir(path) + return name, [os.path.join(path, x) for x in images] + +def give_up_click(name, types): + file = os.path.join(inspiration_system_path, types + "_abandoned.txt") + name_list = read_name_list(file) + if name not in name_list: + save_name_list(file, name) + +def collect_click(name, types): + file = os.path.join(inspiration_system_path, types + "_faverites.txt") + print(file) + name_list = read_name_list(file) + print(name_list) + if name not in name_list: + save_name_list(file, name) + +def moveout_click(name, types): + file = os.path.join(inspiration_system_path, types + "_faverites.txt") + name_list = read_name_list(file) + if name not in name_list: + save_name_list(file, name) + +def source_change(source): + if source == "Abandoned" or source == "Favorites": + return gradio.Button.update(visible=True, value=f"Move out {source}") + else: + return gradio.Button.update(visible=False) + +def ui(gr, opts): + with gr.Blocks(analytics_enabled=False) as inspiration: + flag = os.path.exists(inspiration_path) + if flag: + types = os.listdir(inspiration_path) + types = [x for x in types if x != "system"] + flag = len(types) > 0 + if not flag: + os.mkdir(inspiration_path) + gr.HTML(""" +
" + """) + return inspiration + if not os.path.exists(inspiration_system_path): + os.mkdir(inspiration_system_path) + gallery, names = get_inspiration_images("Exclude abandoned", types[0]) + with gr.Row(): + with gr.Column(scale=2): + inspiration_gallery = gr.Gallery(gallery, show_label=False, elem_id="inspiration_gallery").style(grid=5, height='auto') + with gr.Column(scale=1): + types = gr.Dropdown(choices=types, value=types[0], label="Type", visible=len(types) > 1) + with gr.Row(): + source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") + get_inspiration = gr.Button("Get inspiration") + name = gr.Textbox(show_label=False, interactive=False) + with gr.Row(): + send_to_txt2img = gr.Button('to txt2img') + send_to_img2img = gr.Button('to img2img') + style_gallery = gr.Gallery(show_label=False, elem_id="inspiration_style_gallery").style(grid=2, height='auto') + + collect = gr.Button('Collect') + give_up = gr.Button("Don't show any more") + moveout = gr.Button("Move out", visible=False) + with gr.Row(): + select_button = gr.Button('set button', elem_id="inspiration_select_button") + name_list = gr.State(names) + source.change(source_change, inputs=[source], outputs=[moveout]) + get_inspiration.click(get_inspiration_images, inputs=[source, types], outputs=[inspiration_gallery, name_list]) + select_button.click(select_click, _js="inspiration_selected", inputs=[name, types, name_list], outputs=[name, style_gallery]) + give_up.click(give_up_click, inputs=[name, types], outputs=None) + collect.click(collect_click, inputs=[name, types], outputs=None) + return inspiration diff --git a/modules/shared.py b/modules/shared.py index faede821..ae033710 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -78,6 +78,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") cmd_opts = parser.parse_args() restricted_opts = [ diff --git a/modules/ui.py b/modules/ui.py index a2dbd41e..6a0a3c3b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -41,7 +41,8 @@ from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.images_history as img_his +import modules.images_history as images_history +import modules.inspiration as inspiration # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -1082,9 +1083,9 @@ def create_ui(wrap_gradio_gpu_call): upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) - + with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers] , value=shared.sd_upscalers[0].name, type="index") with gr.Group(): extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") @@ -1178,7 +1179,8 @@ def create_ui(wrap_gradio_gpu_call): "i2i":img2img_paste_fields } - images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) + browser_interface = images_history.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) + inspiration_interface = inspiration.ui(gr, opts) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): @@ -1595,7 +1597,8 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (images_history, "History", "images_history"), + (browser_interface, "History", "images_history"), + (inspiration_interface, "Inspiration", "inspiration"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), (settings_interface, "Settings", "settings"), diff --git a/scripts/create_inspiration_images.py b/scripts/create_inspiration_images.py new file mode 100644 index 00000000..6a20def8 --- /dev/null +++ b/scripts/create_inspiration_images.py @@ -0,0 +1,45 @@ +import csv, os, shutil +import modules.scripts as scripts +from modules import processing, shared, sd_samplers, images +from modules.processing import Processed + + +class Script(scripts.Script): + def title(self): + return "Create artists style image" + + def show(self, is_img2img): + return not is_img2img + + def ui(self, is_img2img): + return [] + def show(self, is_img2img): + return not is_img2img + + def run(self, p): #, max_snapshoots_num): + path = os.path.join("style_snapshoot", "artist") + if not os.path.exists(path): + os.makedirs(path) + p.do_not_save_samples = True + p.do_not_save_grid = True + p.negative_prompt = "portrait photo" + f = open('artists.csv') + f_csv = csv.reader(f) + for row in f_csv: + name = row[0] + artist_path = os.path.join(path, name) + if not os.path.exists(artist_path): + os.mkdir(artist_path) + if len(os.listdir(artist_path)) > 0: + continue + print(name) + p.prompt = name + processed = processing.process_images(p) + for img in processed.images: + i = 0 + filename = os.path.join(artist_path, format(0, "03d") + ".jpg") + while os.path.exists(filename): + i += 1 + filename = os.path.join(artist_path, format(i, "03d") + ".jpg") + img.save(filename, quality=70) + return processed diff --git a/webui.py b/webui.py index 177bef74..5923905f 100644 --- a/webui.py +++ b/webui.py @@ -72,6 +72,11 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) def initialize(): + if cmd_opts.ui_debug_mode: + class enmpty(): + name = None + shared.sd_upscalers = [enmpty()] + return modelloader.cleanup_models() modules.sd_models.setup_model() codeformer.setup_model(cmd_opts.codeformer_models_path) -- cgit v1.2.3 From 108be15500aac590b4e00420635d7b61fccfa530 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Fri, 21 Oct 2022 01:00:41 +0900 Subject: fix bugs and optimizations --- modules/hypernetworks/hypernetwork.py | 105 +++++++++++++++++++--------------- 1 file changed, 59 insertions(+), 46 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 905cbeef..893ba110 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -36,14 +36,14 @@ class HypernetworkModule(torch.nn.Module): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) # if skip_first_layer because first parameters potentially contain negative values # if i < 1: continue + if add_layer_norm: + linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) if activation_func in HypernetworkModule.activation_dict: linears.append(HypernetworkModule.activation_dict[activation_func]()) else: print("Invalid key {} encountered as activation function!".format(activation_func)) # if use_dropout: # linears.append(torch.nn.Dropout(p=0.3)) - if add_layer_norm: - linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) self.linear = torch.nn.Sequential(*linears) @@ -115,11 +115,24 @@ class Hypernetwork: for k, layers in self.layers.items(): for layer in layers: - layer.train() res += layer.trainables() return res + def eval(self): + for k, layers in self.layers.items(): + for layer in layers: + layer.eval() + for items in self.weights(): + items.requires_grad = False + + def train(self): + for k, layers in self.layers.items(): + for layer in layers: + layer.train() + for items in self.weights(): + items.requires_grad = True + def save(self, filename): state_dict = {} @@ -290,10 +303,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.sd_model.first_stage_model.to(devices.cpu) hypernetwork = shared.loaded_hypernetwork - weights = hypernetwork.weights() - for weight in weights: - weight.requires_grad = True - losses = torch.zeros((32,)) last_saved_file = "" @@ -304,10 +313,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... - optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) + optimizer = torch.optim.AdamW(hypernetwork.weights(), lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) + hypernetwork.train() for i, entries in pbar: hypernetwork.step = i + ititial_step @@ -328,8 +337,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log losses[hypernetwork.step % losses.shape[0]] = loss.item() - optimizer.zero_grad() + optimizer.zero_grad(set_to_none=True) loss.backward() + del loss optimizer.step() mean_loss = losses.mean() if torch.isnan(mean_loss): @@ -346,44 +356,47 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: + torch.cuda.empty_cache() last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') + with torch.no_grad(): + hypernetwork.eval() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + ) - optimizer.zero_grad() - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) - - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - ) - - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - - preview_text = p.prompt - - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images)>0 else None - - if unload: - shared.sd_model.cond_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.cpu) - - if image is not None: - shared.state.current_image = image - image.save(last_saved_image) - last_saved_image += f", prompt: {preview_text}" + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + + preview_text = p.prompt + + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images)>0 else None + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + + if image is not None: + shared.state.current_image = image + image.save(last_saved_image) + last_saved_image += f", prompt: {preview_text}" + + hypernetwork.train() shared.state.job_no = hypernetwork.step -- cgit v1.2.3 From f89829ec3a0baceb445451ad98d4fb4323e922aa Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 21 Oct 2022 01:37:11 +0900 Subject: Revert "fix bugs and optimizations" This reverts commit 108be15500aac590b4e00420635d7b61fccfa530. --- modules/hypernetworks/hypernetwork.py | 105 +++++++++++++++------------------- 1 file changed, 46 insertions(+), 59 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 893ba110..905cbeef 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -36,14 +36,14 @@ class HypernetworkModule(torch.nn.Module): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) # if skip_first_layer because first parameters potentially contain negative values # if i < 1: continue - if add_layer_norm: - linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) if activation_func in HypernetworkModule.activation_dict: linears.append(HypernetworkModule.activation_dict[activation_func]()) else: print("Invalid key {} encountered as activation function!".format(activation_func)) # if use_dropout: # linears.append(torch.nn.Dropout(p=0.3)) + if add_layer_norm: + linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) self.linear = torch.nn.Sequential(*linears) @@ -115,24 +115,11 @@ class Hypernetwork: for k, layers in self.layers.items(): for layer in layers: + layer.train() res += layer.trainables() return res - def eval(self): - for k, layers in self.layers.items(): - for layer in layers: - layer.eval() - for items in self.weights(): - items.requires_grad = False - - def train(self): - for k, layers in self.layers.items(): - for layer in layers: - layer.train() - for items in self.weights(): - items.requires_grad = True - def save(self, filename): state_dict = {} @@ -303,6 +290,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.sd_model.first_stage_model.to(devices.cpu) hypernetwork = shared.loaded_hypernetwork + weights = hypernetwork.weights() + for weight in weights: + weight.requires_grad = True + losses = torch.zeros((32,)) last_saved_file = "" @@ -313,10 +304,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - optimizer = torch.optim.AdamW(hypernetwork.weights(), lr=scheduler.learn_rate) + # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... + optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - hypernetwork.train() for i, entries in pbar: hypernetwork.step = i + ititial_step @@ -337,9 +328,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log losses[hypernetwork.step % losses.shape[0]] = loss.item() - optimizer.zero_grad(set_to_none=True) + optimizer.zero_grad() loss.backward() - del loss optimizer.step() mean_loss = losses.mean() if torch.isnan(mean_loss): @@ -356,47 +346,44 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: - torch.cuda.empty_cache() last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') - with torch.no_grad(): - hypernetwork.eval() - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) - - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - ) - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - - preview_text = p.prompt - - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images)>0 else None - - if unload: - shared.sd_model.cond_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.cpu) - - if image is not None: - shared.state.current_image = image - image.save(last_saved_image) - last_saved_image += f", prompt: {preview_text}" - - hypernetwork.train() + optimizer.zero_grad() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + + preview_text = p.prompt + + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images)>0 else None + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + + if image is not None: + shared.state.current_image = image + image.save(last_saved_image) + last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step -- cgit v1.2.3 From 92a17a7a4a13fceb3c3e25a2e854b2a7dd6eb5df Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 20 Oct 2022 09:45:03 -0700 Subject: Made dummy latents smaller. Minor code cleanups --- modules/processing.py | 7 ++++--- modules/sd_samplers.py | 6 ++++-- 2 files changed, 8 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 3caac25e..539cde38 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -557,7 +557,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: # Dummy zero conditioning if we're not using inpainting model. # Still takes up a bit of memory, but no encoder call. - image_conditioning = torch.zeros(x.shape[0], 5, x.shape[-2], x.shape[-1], dtype=x.dtype, device=x.device) + # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. + image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) return image_conditioning @@ -759,8 +760,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) else: self.image_conditioning = torch.zeros( - self.init_latent.shape[0], 5, self.init_latent.shape[-2], self.init_latent.shape[-1], - dtype=self.init_latent.dtype, + self.init_latent.shape[0], 5, 1, 1, + dtype=self.init_latent.dtype, device=self.init_latent.device ) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index c21be26e..cc682593 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -138,7 +138,7 @@ class VanillaStableDiffusionSampler: if self.stop_at is not None and self.step > self.stop_at: raise InterruptedException - # Have to unwrap the inpainting conditioning here to perform pre-preocessing + # Have to unwrap the inpainting conditioning here to perform pre-processing image_conditioning = None if isinstance(cond, dict): image_conditioning = cond["c_concat"][0] @@ -146,7 +146,7 @@ class VanillaStableDiffusionSampler: unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) + unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' cond = tensor @@ -165,6 +165,8 @@ class VanillaStableDiffusionSampler: img_orig = self.sampler.model.q_sample(self.init_latent, ts) x_dec = img_orig * self.mask + self.nmask * x_dec + # Wrap the image conditioning back up since the DDIM code can accept the dict directly. + # Note that they need to be lists because it just concatenates them later. if image_conditioning is not None: cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} -- cgit v1.2.3 From d1cb08bfb221cd1b0cfc6078162b4e206ea80a5c Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Thu, 20 Oct 2022 22:49:06 +0300 Subject: fix skip and interrupt for highres. fix option --- modules/processing.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index bcb0c32c..6324ca91 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -587,9 +587,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) - - return samples + return self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) or samples class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): -- cgit v1.2.3 From 708c3a7bd8ce68cbe1aa7c268e5a4b1980affc9f Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 20 Oct 2022 13:28:43 -0700 Subject: Added PLMS hijack and made sure to always replace methods --- modules/sd_hijack_inpainting.py | 163 ++++++++++++++++++++++++++++++++++++++-- modules/sd_models.py | 3 +- 2 files changed, 157 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index d4d28d2e..43938071 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -1,16 +1,14 @@ import torch -import numpy as np -from tqdm import tqdm -from einops import rearrange, repeat +from einops import repeat from omegaconf import ListConfig -from types import MethodType - import ldm.models.diffusion.ddpm import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.ddim import DDIMSampler, noise_like # ================================================================================================= @@ -19,7 +17,7 @@ from ldm.models.diffusion.ddim import DDIMSampler, noise_like # https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py # ================================================================================================= @torch.no_grad() -def sample(self, +def sample_ddim(self, S, batch_size, shape, @@ -132,6 +130,153 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F return x_prev, pred_x0 +# ================================================================================================= +# Monkey patch PLMSSampler methods. +# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes. +# Adapted from: +# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py +# ================================================================================================= +@torch.no_grad() +def sample_plms(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + +@torch.no_grad() +def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t + # ================================================================================================= # Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config. # Adapted from: @@ -175,5 +320,9 @@ def should_hijack_inpainting(checkpoint_info): def do_inpainting_hijack(): ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion + ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim - ldm.models.diffusion.ddim.DDIMSampler.sample = sample \ No newline at end of file + ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim + + ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms + ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms \ No newline at end of file diff --git a/modules/sd_models.py b/modules/sd_models.py index 47836d25..7072db08 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -214,8 +214,6 @@ def load_model(): sd_config = OmegaConf.load(checkpoint_info.config) if should_hijack_inpainting(checkpoint_info): - do_inpainting_hijack() - # Hardcoded config for now... sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" sd_config.model.params.use_ema = False @@ -225,6 +223,7 @@ def load_model(): # Create a "fake" config with a different name so that we know to unload it when switching models. checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) -- cgit v1.2.3 From d23a46ceaa76af2847f11172f32c92665c268b1b Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Thu, 20 Oct 2022 23:49:14 +0300 Subject: Different approach to skip/interrupt with highres fix --- modules/processing.py | 4 +++- modules/sd_samplers.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 6324ca91..bcb0c32c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -587,7 +587,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() - return self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) or samples + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) + + return samples class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index b58e810b..7ff77c01 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -196,6 +196,7 @@ class VanillaStableDiffusionSampler: x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) self.init_latent = x + self.last_latent = x self.step = 0 samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) @@ -206,6 +207,7 @@ class VanillaStableDiffusionSampler: self.initialize(p) self.init_latent = None + self.last_latent = x self.step = 0 steps = steps or p.steps @@ -388,6 +390,7 @@ class KDiffusionSampler: extra_params_kwargs['sigmas'] = sigma_sched self.model_wrap_cfg.init_latent = x + self.last_latent = x samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) @@ -414,6 +417,7 @@ class KDiffusionSampler: else: extra_params_kwargs['sigmas'] = sigmas + self.last_latent = x samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples -- cgit v1.2.3 From 49533eed9e3aad19e9868ee140708baec4fd44be Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 20 Oct 2022 16:01:27 -0700 Subject: XY grid correctly re-assignes model when config changes --- modules/sd_models.py | 6 +++--- scripts/xy_grid.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 7072db08..fea84630 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -204,9 +204,9 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info -def load_model(): +def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack - checkpoint_info = select_checkpoint() + checkpoint_info = checkpoint_info or select_checkpoint() if checkpoint_info.config != shared.cmd_opts.config: print(f"Loading config from: {checkpoint_info.config}") @@ -249,7 +249,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - shared.sd_model = load_model() + shared.sd_model = load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 5cca168a..eff0c942 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs): if info is None: raise RuntimeError(f"Unknown checkpoint: {x}") modules.sd_models.reload_model_weights(shared.sd_model, info) + p.sd_model = shared.sd_model def confirm_checkpoints(p, xs): -- cgit v1.2.3 From a3b047b7c74dc6ca07f40aee778997fc1889d72f Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Thu, 20 Oct 2022 19:28:58 -0500 Subject: add settings option to toggle button visibility --- javascript/ui.js | 1 - modules/shared.py | 1 + modules/ui.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/javascript/ui.js b/javascript/ui.js index 39eae1f7..f19af550 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -163,7 +163,6 @@ function selected_tab_id() { } function trash_prompt(_,_, is_img2img) { -//txt2img_token_button if(!confirm("Delete prompt?")) return false diff --git a/modules/shared.py b/modules/shared.py index faede821..7e9c2696 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -300,6 +300,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + "trash_prompt_visible": OptionInfo(True, "Show trash prompt button"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui.py b/modules/ui.py index bde546cc..13c0b4ca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -509,7 +509,7 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt") + trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt", visible=opts.trash_prompt_visible) token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") -- cgit v1.2.3 From 45872181902ada06267e2de601586d512cf5df1a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 09:00:39 +0300 Subject: updated readme and some small stylistic changes to code --- README.md | 1 + modules/processing.py | 14 ++++++-------- modules/sd_hijack_inpainting.py | 3 +++ 3 files changed, 10 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/README.md b/README.md index 859a91b6..a98bb00b 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - No token limit for prompts (original stable diffusion lets you use up to 75 tokens) - DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args) - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args) +- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. diff --git a/modules/processing.py b/modules/processing.py index 539cde38..21786968 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -540,11 +540,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - - def create_dummy_mask(self, x, first_phase: bool = False): + def create_dummy_mask(self, x, width=None, height=None): if self.sampler.conditioning_key in {'hybrid', 'concat'}: - height = self.firstphase_height if first_phase else self.height - width = self.firstphase_width if first_phase else self.width + height = height or self.height + width = width or self.width # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) @@ -571,7 +570,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return samples x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, first_phase=True)) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height)) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] @@ -634,6 +633,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.inpainting_mask_invert = inpainting_mask_invert self.mask = None self.nmask = None + self.image_conditioning = None def init(self, all_prompts, all_seeds, all_subseeds): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) @@ -735,9 +735,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - conditioning_key = self.sampler.conditioning_key - - if conditioning_key in {'hybrid', 'concat'}: + if self.sampler.conditioning_key in {'hybrid', 'concat'}: if self.image_mask is not None: conditioning_mask = np.array(self.image_mask.convert("L")) conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 43938071..fd92a335 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -301,6 +301,7 @@ def get_unconditional_conditioning(self, batch_size, null_label=None): c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) return c + class LatentInpaintDiffusion(LatentDiffusion): def __init__( self, @@ -314,9 +315,11 @@ class LatentInpaintDiffusion(LatentDiffusion): assert self.masked_image_key in concat_keys self.concat_keys = concat_keys + def should_hijack_inpainting(checkpoint_info): return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml") + def do_inpainting_hijack(): ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion -- cgit v1.2.3 From 74088c2a06a975092806362aede22f82716cb011 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 20 Oct 2022 08:18:02 +0300 Subject: allow float sizes for hypernet's layer_structure --- modules/hypernetworks/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 08f75f15..e0741d08 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -15,7 +15,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm assert not os.path.exists(fn), f"file {fn} already exists" if type(layer_structure) == str: - layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure))) + layer_structure = [float(x.strip()) for x in layer_structure.split(",")] hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( name=name, -- cgit v1.2.3 From 60872c5b404114336f9ca0c671ba88fa4a8201c9 Mon Sep 17 00:00:00 2001 From: winterspringsummer Date: Thu, 20 Oct 2022 19:10:32 +0900 Subject: Fixed path issue while extras batch processing --- modules/extras.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index b853fa5b..f9796624 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -118,10 +118,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ while len(cached_images) > 2: del cached_images[next(iter(cached_images.keys()))] - - images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, - forced_filename=image_name if opts.use_original_name_batch else None) + + if opts.use_original_name_batch and image_name != None: + basename = os.path.splitext(os.path.basename(image_name))[0] + else: + basename = '' + + images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) if opts.enable_pnginfo: image.info = existing_pnginfo -- cgit v1.2.3 From fb5a8cf0d9ed027ea3aa2e5422c946d8e6e72efe Mon Sep 17 00:00:00 2001 From: winterspringsummer Date: Thu, 20 Oct 2022 21:31:29 +0900 Subject: Added try except to extras batch from directory --- modules/extras.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index f9796624..0d817cf9 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -41,7 +41,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ return outputs, "Please select an input directory.", '' image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)] for img in image_list: - image = Image.open(img) + try: + image = Image.open(img) + except Exception: + continue imageArr.append(image) imageNameArr.append(img) else: @@ -122,10 +125,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ if opts.use_original_name_batch and image_name != None: basename = os.path.splitext(os.path.basename(image_name))[0] else: - basename = '' + basename = None - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) + images.save_image(image, path=outpath, basename='', seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=basename) if opts.enable_pnginfo: image.info = existing_pnginfo -- cgit v1.2.3 From a13c3bed3cec27afe3c015d3d62db36e25b10d1f Mon Sep 17 00:00:00 2001 From: winterspringsummer Date: Thu, 20 Oct 2022 21:43:27 +0900 Subject: Fixed path issue while extras batch processing --- modules/extras.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 0d817cf9..ac85142c 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -125,10 +125,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ if opts.use_original_name_batch and image_name != None: basename = os.path.splitext(os.path.basename(image_name))[0] else: - basename = None + basename = '' - images.save_image(image, path=outpath, basename='', seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=basename) + images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) if opts.enable_pnginfo: image.info = existing_pnginfo -- cgit v1.2.3 From 9d71eef02e7395e179b8d5e61e6d91ddd8928d2e Mon Sep 17 00:00:00 2001 From: winterspringsummer Date: Fri, 21 Oct 2022 09:23:13 +0900 Subject: sort file list in alphabetical ordering in extras --- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index ac85142c..22c5a1c1 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -39,7 +39,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ if input_dir == '': return outputs, "Please select an input directory.", '' - image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)] + image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)] for img in image_list: try: image = Image.open(img) -- cgit v1.2.3 From c23f666dba2b484d521d2dc4be91cf9e09312647 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 09:47:43 +0300 Subject: a more strict check for activation type and a more reasonable check for type of layer in hypernets --- modules/hypernetworks/hypernetwork.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d617680..84e7e350 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -32,10 +32,16 @@ class HypernetworkModule(torch.nn.Module): linears = [] for i in range(len(layer_structure) - 1): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) + if activation_func == "relu": linears.append(torch.nn.ReLU()) - if activation_func == "leakyrelu": + elif activation_func == "leakyrelu": linears.append(torch.nn.LeakyReLU()) + elif activation_func == 'linear' or activation_func is None: + pass + else: + raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') + if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module): self.load_state_dict(state_dict) else: for layer in self.linear: - if not "ReLU" in layer.__str__(): + if type(layer) == torch.nn.Linear: layer.weight.data.normal_(mean=0.0, std=0.01) layer.bias.data.zero_() @@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module): def trainables(self): layer_structure = [] for layer in self.linear: - if not "ReLU" in layer.__str__(): + if type(layer) == torch.nn.Linear: layer_structure += [layer.weight, layer.bias] return layer_structure -- cgit v1.2.3 From 7157e5d064741fa57ca81a2c6432a651f21ee82f Mon Sep 17 00:00:00 2001 From: Patryk Wychowaniec Date: Thu, 20 Oct 2022 19:22:59 +0200 Subject: interrogate: Fix CLIP-interrogation on CPU Currently, trying to perform CLIP interrogation on a CPU fails, saying: ``` RuntimeError: "slow_conv2d_cpu" not implemented for 'Half' ``` This merge request fixes this issue by detecting whether the target device is CPU and, if so, force-enabling `--no-half` and passing `device="cpu"` to `clip.load()` (which then does some extra tricks to ensure it works correctly on CPU). --- modules/interrogate.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/interrogate.py b/modules/interrogate.py index 64b91eb4..65b05d34 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -28,9 +28,11 @@ class InterrogateModels: clip_preprocess = None categories = None dtype = None + running_on_cpu = None def __init__(self, content_dir): self.categories = [] + self.running_on_cpu = devices.device_interrogate == torch.device("cpu") if os.path.exists(content_dir): for filename in os.listdir(content_dir): @@ -53,7 +55,11 @@ class InterrogateModels: def load_clip_model(self): import clip - model, preprocess = clip.load(clip_model_name) + if self.running_on_cpu: + model, preprocess = clip.load(clip_model_name, device="cpu") + else: + model, preprocess = clip.load(clip_model_name) + model.eval() model = model.to(devices.device_interrogate) @@ -62,14 +68,14 @@ class InterrogateModels: def load(self): if self.blip_model is None: self.blip_model = self.load_blip_model() - if not shared.cmd_opts.no_half: + if not shared.cmd_opts.no_half and not self.running_on_cpu: self.blip_model = self.blip_model.half() self.blip_model = self.blip_model.to(devices.device_interrogate) if self.clip_model is None: self.clip_model, self.clip_preprocess = self.load_clip_model() - if not shared.cmd_opts.no_half: + if not shared.cmd_opts.no_half and not self.running_on_cpu: self.clip_model = self.clip_model.half() self.clip_model = self.clip_model.to(devices.device_interrogate) -- cgit v1.2.3 From b69c37d25e4ffc56e8f8c247fa2c38b4648cefb7 Mon Sep 17 00:00:00 2001 From: guaneec Date: Thu, 20 Oct 2022 22:21:12 +0800 Subject: Allow datasets with only 1 image in TI --- modules/textual_inversion/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 23bb4b6a..5b1c5002 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -83,7 +83,7 @@ class PersonalizedBase(Dataset): self.dataset.append(entry) - assert len(self.dataset) > 1, "No images have been found in the dataset." + assert len(self.dataset) > 0, "No images have been found in the dataset." self.length = len(self.dataset) * repeats // batch_size self.initial_indexes = np.arange(len(self.dataset)) @@ -91,7 +91,7 @@ class PersonalizedBase(Dataset): self.shuffle() def shuffle(self): - self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()] def create_text(self, filename_text): text = random.choice(self.lines) -- cgit v1.2.3 From 5245c7a4935f67b677da0f5a1fc2b74c074aa0e2 Mon Sep 17 00:00:00 2001 From: timntorres Date: Wed, 19 Oct 2022 12:21:32 -0700 Subject: Issue #2921-Give PNG info to Hypernet previews. --- modules/hypernetworks/hypernetwork.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 84e7e350..68c8f26d 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -256,6 +256,9 @@ def stack_conds(conds): def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + # images is required here to give training previews their infotext. Importing this at the very top causes a circular dependency. + from modules import images + assert hypernetwork_name, 'hypernetwork not selected' path = shared.hypernetworks.get(hypernetwork_name, None) @@ -298,6 +301,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log last_saved_file = "" last_saved_image = "" + forced_filename = "" ititial_step = hypernetwork.step or 0 if ititial_step > steps: @@ -345,7 +349,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: - last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') + forced_filename = f'{hypernetwork_name}-{hypernetwork.step}' + last_saved_image = os.path.join(images_dir, forced_filename) optimizer.zero_grad() shared.sd_model.cond_stage_model.to(devices.device) @@ -381,7 +386,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if image is not None: shared.state.current_image = image - image.save(last_saved_image) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step -- cgit v1.2.3 From 6014fb8afbe05c8d02fffe7a36a2e48128713bd2 Mon Sep 17 00:00:00 2001 From: timntorres Date: Wed, 19 Oct 2022 12:22:23 -0700 Subject: Do nothing if image file already exists. --- modules/images.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index b9589563..550e53ae 100644 --- a/modules/images.py +++ b/modules/images.py @@ -416,7 +416,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /') path = os.path.join(path, dirname) - os.makedirs(path, exist_ok=True) + try: + os.makedirs(path, exist_ok=True) + except FileExistsError: + # If the file already exists, continue and allow said file to be overwritten. + pass if forced_filename is None: basecount = get_next_sequence_number(path, basename) -- cgit v1.2.3 From 4ff274e1e35bb642687253ce744d2cfa738ab293 Mon Sep 17 00:00:00 2001 From: timntorres Date: Wed, 19 Oct 2022 12:32:22 -0700 Subject: Revise comments. --- modules/hypernetworks/hypernetwork.py | 2 +- modules/images.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 68c8f26d..3f96361c 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -256,7 +256,7 @@ def stack_conds(conds): def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # images is required here to give training previews their infotext. Importing this at the very top causes a circular dependency. + # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images assert hypernetwork_name, 'hypernetwork not selected' diff --git a/modules/images.py b/modules/images.py index 550e53ae..b8834e3c 100644 --- a/modules/images.py +++ b/modules/images.py @@ -419,7 +419,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i try: os.makedirs(path, exist_ok=True) except FileExistsError: - # If the file already exists, continue and allow said file to be overwritten. + # If the file already exists, allow said file to be overwritten. pass if forced_filename is None: -- cgit v1.2.3 From 2273e752fb3e578f1047f6d38b96330b07bf61a9 Mon Sep 17 00:00:00 2001 From: timntorres Date: Wed, 19 Oct 2022 14:23:48 -0700 Subject: Remove redundant try/except. --- modules/images.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index b8834e3c..b9589563 100644 --- a/modules/images.py +++ b/modules/images.py @@ -416,11 +416,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /') path = os.path.join(path, dirname) - try: - os.makedirs(path, exist_ok=True) - except FileExistsError: - # If the file already exists, allow said file to be overwritten. - pass + os.makedirs(path, exist_ok=True) if forced_filename is None: basecount = get_next_sequence_number(path, basename) -- cgit v1.2.3 From 03a1e288c4973dd2dff57a97469b40f146b6fccf Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 10:13:24 +0300 Subject: turns out LayerNorm also has weight and bias and needs to be pre-multiplied and trained for hypernets --- modules/hypernetworks/hypernetwork.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3274a802..b1a5d0c7 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -52,7 +52,7 @@ class HypernetworkModule(torch.nn.Module): self.load_state_dict(state_dict) else: for layer in self.linear: - if type(layer) == torch.nn.Linear: + if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm: layer.weight.data.normal_(mean=0.0, std=0.01) layer.bias.data.zero_() @@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module): def trainables(self): layer_structure = [] for layer in self.linear: - if type(layer) == torch.nn.Linear: + if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm: layer_structure += [layer.weight, layer.bias] return layer_structure -- cgit v1.2.3 From bf30673f5132c8f28357b31224c54331e788d3e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 10:19:25 +0300 Subject: Fix Hypernet infotext string split bug for PR #3283 --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 21786968..d1deffa9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]), + "Hypernet": (None if shared.loaded_hypernetwork is None else os.path.splitext(os.path.basename(shared.loaded_hypernetwork.filename))[0]), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), -- cgit v1.2.3 From df5706409386cc2e88718bd9101045587c39f8bb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 16:10:51 +0300 Subject: do not load aesthetic clip model until it's needed add refresh button for aesthetic embeddings add aesthetic params to images' infotext --- modules/aesthetic_clip.py | 40 +++++++++++++++++++---- modules/generation_parameters_copypaste.py | 18 +++++++++-- modules/img2img.py | 5 +-- modules/processing.py | 4 +-- modules/sd_models.py | 3 -- modules/txt2img.py | 4 +-- modules/ui.py | 52 ++++++++++++++++++++---------- style.css | 2 +- 8 files changed, 89 insertions(+), 39 deletions(-) (limited to 'modules') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index 34efa931..8c828541 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -40,6 +40,8 @@ def iter_to_batched(iterable, n=1): def create_ui(): + import modules.ui + with gr.Group(): with gr.Accordion("Open for Clip Aesthetic!", open=False): with gr.Row(): @@ -55,6 +57,8 @@ def create_ui(): label="Aesthetic imgs embedding", value="None") + modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings") + with gr.Row(): aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", @@ -66,11 +70,21 @@ def create_ui(): return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative +aesthetic_clip_model = None + + +def aesthetic_clip(): + global aesthetic_clip_model + + if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path: + aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path) + aesthetic_clip_model.cpu() + + return aesthetic_clip_model + + def generate_imgs_embd(name, folder, batch_size): - # clipModel = CLIPModel.from_pretrained( - # shared.sd_model.cond_stage_model.clipModel.name_or_path - # ) - model = shared.clip_model.to(device) + model = aesthetic_clip().to(device) processor = CLIPProcessor.from_pretrained(model.name_or_path) with torch.no_grad(): @@ -91,7 +105,7 @@ def generate_imgs_embd(name, folder, batch_size): path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") torch.save(embs, path) - model = model.cpu() + model.cpu() del processor del embs gc.collect() @@ -132,7 +146,7 @@ class AestheticCLIP: self.image_embs = None self.load_image_embs(None) - def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, + def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, aesthetic_slerp=True, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False): @@ -145,6 +159,18 @@ class AestheticCLIP: self.aesthetic_steps = aesthetic_steps self.load_image_embs(image_embs_name) + if self.image_embs_name is not None: + p.extra_generation_params.update({ + "Aesthetic LR": aesthetic_lr, + "Aesthetic weight": aesthetic_weight, + "Aesthetic steps": aesthetic_steps, + "Aesthetic embedding": self.image_embs_name, + "Aesthetic slerp": aesthetic_slerp, + "Aesthetic text": aesthetic_imgs_text, + "Aesthetic text negative": aesthetic_text_negative, + "Aesthetic slerp angle": aesthetic_slerp_angle, + }) + def set_skip(self, skip): self.skip = skip @@ -168,7 +194,7 @@ class AestheticCLIP: tokens = torch.asarray(remade_batch_tokens).to(device) - model = copy.deepcopy(shared.clip_model).to(device) + model = copy.deepcopy(aesthetic_clip()).to(device) model.requires_grad_(True) if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: text_embs_2 = model.get_text_features( diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 0f041449..f73647da 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -4,13 +4,22 @@ import gradio as gr from modules.shared import script_path from modules import shared -re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" +re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") type_of_gr_update = type(gr.update()) +def quote(text): + if ',' not in str(text): + return text + + text = str(text) + text = text.replace('\\', '\\\\') + text = text.replace('"', '\\"') + return f'"{text}"' + def parse_generation_parameters(x: str): """parses generation parameters string, the one you see in text field under the picture in UI: ``` @@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None): else: try: valtype = type(output.value) - val = valtype(v) + + if valtype == bool and v == "False": + val = False + else: + val = valtype(v) + res.append(gr.update(value=val)) except Exception: res.append(gr.update()) diff --git a/modules/img2img.py b/modules/img2img.py index bc7c66bc..eea5199b 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -109,10 +109,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpainting_mask_invert=inpainting_mask_invert, ) - shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), - aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative) + shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index d1deffa9..f0852cd5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -12,7 +12,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -318,7 +318,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params.update(p.extra_generation_params) - generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" diff --git a/modules/sd_models.py b/modules/sd_models.py index 05a1df28..b1c91b0d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -234,9 +234,6 @@ def load_model(checkpoint_info=None): sd_hijack.model_hijack.hijack(sd_model) - if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path: - shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path) - sd_model.eval() print(f"Model loaded.") diff --git a/modules/txt2img.py b/modules/txt2img.py index 32ed1d8d..1761cfa2 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -36,9 +36,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: firstphase_height=firstphase_height if enable_hr else None, ) - shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), - aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, - aesthetic_text_negative) + shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/ui.py b/modules/ui.py index 381ca925..0d020de6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -597,27 +597,29 @@ def apply_setting(key, value): return value -def create_ui(wrap_gradio_gpu_call): - import modules.img2img - import modules.txt2img +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args - def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args + for k, v in args.items(): + setattr(refresh_component, k, v) - for k, v in args.items(): - setattr(refresh_component, k, v) + return gr.update(**(args or {})) - return gr.update(**(args or {})) + refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + +def create_ui(wrap_gradio_gpu_call): + import modules.img2img + import modules.txt2img - refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn = refresh, - inputs = [], - outputs = [refresh_component] - ) - return refresh_button with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) @@ -802,6 +804,14 @@ def create_ui(wrap_gradio_gpu_call): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (firstphase_width, "First pass size-1"), (firstphase_height, "First pass size-2"), + (aesthetic_lr, "Aesthetic LR"), + (aesthetic_weight, "Aesthetic weight"), + (aesthetic_steps, "Aesthetic steps"), + (aesthetic_imgs, "Aesthetic embedding"), + (aesthetic_slerp, "Aesthetic slerp"), + (aesthetic_imgs_text, "Aesthetic text"), + (aesthetic_text_negative, "Aesthetic text negative"), + (aesthetic_slerp_angle, "Aesthetic slerp angle"), ] txt2img_preview_params = [ @@ -1077,6 +1087,14 @@ def create_ui(wrap_gradio_gpu_call): (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), (denoising_strength, "Denoising strength"), + (aesthetic_lr_im, "Aesthetic LR"), + (aesthetic_weight_im, "Aesthetic weight"), + (aesthetic_steps_im, "Aesthetic steps"), + (aesthetic_imgs_im, "Aesthetic embedding"), + (aesthetic_slerp_im, "Aesthetic slerp"), + (aesthetic_imgs_text_im, "Aesthetic text"), + (aesthetic_text_negative_im, "Aesthetic text negative"), + (aesthetic_slerp_angle_im, "Aesthetic slerp angle"), ] token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) diff --git a/style.css b/style.css index 26ae36a5..5d2bacc9 100644 --- a/style.css +++ b/style.css @@ -477,7 +477,7 @@ input[type="range"]{ padding: 0; } -#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{ +#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_aesthetic_embeddings{ max-width: 2.5em; min-width: 2.5em; height: 2.4em; -- cgit v1.2.3 From 9286fe53de2eef91f13cc3ad5938ddf67ecc8413 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 16:38:06 +0300 Subject: make aestetic embedding ciompatible with prompts longer than 75 tokens --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 36198a3c..1f8587d1 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -332,8 +332,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers.append([1.0] * 75) z1 = self.process_tokens(tokens, multipliers) + z1 = shared.aesthetic_clip(z1, remade_batch_tokens) z = z1 if z is None else torch.cat((z, z1), axis=-2) - z = shared.aesthetic_clip(z, remade_batch_tokens) remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers -- cgit v1.2.3 From d0ea471b0cdaede163c6e7f6fae8535f5c3cd226 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 21 Oct 2022 14:04:41 +0100 Subject: Use opts in textual_inversion image_embedding.py for dynamic fonts --- modules/textual_inversion/image_embedding.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index 898ce3b3..c50b1e7b 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -5,6 +5,7 @@ import zlib from PIL import Image, PngImagePlugin, ImageDraw, ImageFont from fonts.ttf import Roboto import torch +from modules.shared import opts class EmbeddingEncoder(json.JSONEncoder): -- cgit v1.2.3 From 306e2ff6ab8f4c7e94ab55f4f08ab8f94d73d287 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 21 Oct 2022 14:47:21 +0100 Subject: Update image_embedding.py --- modules/textual_inversion/image_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index c50b1e7b..ea653806 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -134,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t from math import cos image = srcimage.copy() - + fontsize = 32 if textfont is None: try: textfont = ImageFont.truetype(opts.font or Roboto, fontsize) @@ -151,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) draw = ImageDraw.Draw(image) - fontsize = 32 + font = ImageFont.truetype(textfont, fontsize) padding = 10 -- cgit v1.2.3 From 51e3dc9ccad157d7161b697a246e26c868d46a7c Mon Sep 17 00:00:00 2001 From: timntorres Date: Fri, 21 Oct 2022 02:11:12 -0700 Subject: Sanitize hypernet name input. --- modules/hypernetworks/ui.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 266f04f6..e6f50a1f 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -11,6 +11,9 @@ from modules.hypernetworks import hypernetwork def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None): + # Remove illegal characters from name. + name = "".join( x for x in name if (x.isalnum() or x in "._- ")) + fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") if not overwrite_old: assert not os.path.exists(fn), f"file {fn} already exists" -- cgit v1.2.3 From 19818f023cfafc472c6c241cab0b72896a168481 Mon Sep 17 00:00:00 2001 From: timntorres Date: Fri, 21 Oct 2022 02:14:02 -0700 Subject: Match hypernet name with filename in all cases. --- modules/hypernetworks/hypernetwork.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b1a5d0c7..6d392be4 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -340,7 +340,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log pbar.set_description(f"loss: {mean_loss:.7f}") if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') + temp = hypernetwork.name + # Before saving, change name to match current checkpoint. + hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') hypernetwork.save(last_saved_file) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { @@ -405,6 +408,9 @@ Last saved image: {html.escape(last_saved_image)}
hypernetwork.sd_checkpoint = checkpoint.hash hypernetwork.sd_checkpoint_name = checkpoint.model_name + # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention). + hypernetwork.name = hypernetwork_name + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt') hypernetwork.save(filename) return hypernetwork, filename -- cgit v1.2.3 From fccad18a59e3c2c33fefbbb1763c6a87a3a68eba Mon Sep 17 00:00:00 2001 From: timntorres Date: Fri, 21 Oct 2022 02:17:26 -0700 Subject: Refer to Hypernet's name, sensibly, by its name variable. --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index f0852cd5..ff1ec4c9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else os.path.splitext(os.path.basename(shared.loaded_hypernetwork.filename))[0]), + "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), -- cgit v1.2.3 From 272fa527bbe93143668ffc16838107b7dca35b40 Mon Sep 17 00:00:00 2001 From: timntorres Date: Fri, 21 Oct 2022 02:41:55 -0700 Subject: Remove unused variable. --- modules/hypernetworks/hypernetwork.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 6d392be4..47d91ea5 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -340,7 +340,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log pbar.set_description(f"loss: {mean_loss:.7f}") if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: - temp = hypernetwork.name # Before saving, change name to match current checkpoint. hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}' last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') -- cgit v1.2.3 From 02e4d4694dd9254a6ca9f05c2eb7b01ea508abc7 Mon Sep 17 00:00:00 2001 From: Rcmcpe Date: Fri, 21 Oct 2022 15:53:35 +0800 Subject: Change option description of unload_models_when_training --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 5c675b80..41d7f08e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -266,7 +266,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('training', "Training"), { - "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), + "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), -- cgit v1.2.3 From 704036ff07b71bf86cadcbbff2bcfeebdd1ed3a6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 17:11:42 +0300 Subject: make aspect ratio overlay work regardless of selected localization --- javascript/aspectRatioOverlay.js | 36 +++++++++++++++++------------------- javascript/dragdrop.js | 2 +- modules/ui.py | 4 ++-- 3 files changed, 20 insertions(+), 22 deletions(-) (limited to 'modules') diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js index 96f1c00d..d3ca2781 100644 --- a/javascript/aspectRatioOverlay.js +++ b/javascript/aspectRatioOverlay.js @@ -3,12 +3,12 @@ let currentWidth = null; let currentHeight = null; let arFrameTimeout = setTimeout(function(){},0); -function dimensionChange(e,dimname){ +function dimensionChange(e, is_width, is_height){ - if(dimname == 'Width'){ + if(is_width){ currentWidth = e.target.value*1.0 } - if(dimname == 'Height'){ + if(is_height){ currentHeight = e.target.value*1.0 } @@ -98,22 +98,20 @@ onUiUpdate(function(){ var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) if(inImg2img){ let inputs = gradioApp().querySelectorAll('input'); - inputs.forEach(function(e){ - let parentLabel = e.parentElement.querySelector('label') - if(parentLabel && parentLabel.innerText){ - if(!e.classList.contains('scrollwatch')){ - if(parentLabel.innerText == 'Width' || parentLabel.innerText == 'Height'){ - e.addEventListener('input', function(e){dimensionChange(e,parentLabel.innerText)} ) - e.classList.add('scrollwatch') - } - if(parentLabel.innerText == 'Width'){ - currentWidth = e.value*1.0 - } - if(parentLabel.innerText == 'Height'){ - currentHeight = e.value*1.0 - } - } - } + inputs.forEach(function(e){ + var is_width = e.parentElement.id == "img2img_width" + var is_height = e.parentElement.id == "img2img_height" + + if((is_width || is_height) && !e.classList.contains('scrollwatch')){ + e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} ) + e.classList.add('scrollwatch') + } + if(is_width){ + currentWidth = e.value*1.0 + } + if(is_height){ + currentHeight = e.value*1.0 + } }) } }); diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index 070cf255..3ed1cb3c 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) { window.document.addEventListener('dragover', e => { const target = e.composedPath()[0]; const imgWrap = target.closest('[data-testid="image"]'); - if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) { + if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) { return; } e.stopPropagation(); diff --git a/modules/ui.py b/modules/ui.py index 0d020de6..85f95792 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -879,8 +879,8 @@ def create_ui(wrap_gradio_gpu_call): sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") with gr.Group(): - width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height") with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) -- cgit v1.2.3 From ac0aa2b18efeeb9220a5994c8dd54c7cdda7cc40 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 17:35:51 +0300 Subject: loading SD VAE, see PR #3303 --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index b1c91b0d..d99dbce8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -155,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd +vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} + + def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash @@ -186,7 +189,7 @@ def load_model_weights(model, checkpoint_info): if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} model.first_stage_model.load_state_dict(vae_dict) model.first_stage_model.to(devices.dtype_vae) -- cgit v1.2.3 From f49c08ea566385db339c6628f65c3a121033f67c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 18:46:02 +0300 Subject: prevent error spam when processing images without txt files for captions --- modules/textual_inversion/preprocess.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 17e4ddc1..33eaddb6 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -122,11 +122,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre continue existing_caption = None - - try: - existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read() - except Exception as e: - print(e) + existing_caption_filename = os.path.splitext(filename)[0] + '.txt' + if os.path.exists(existing_caption_filename): + with open(existing_caption_filename, 'r', encoding="utf8") as file: + existing_caption = file.read() if shared.state.interrupted: break -- cgit v1.2.3 From bb0f1a2cdae3410a41d06ae878f56e29b8154c41 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sat, 22 Oct 2022 01:23:00 +0800 Subject: inspiration finished --- javascript/inspiration.js | 27 ++++--- modules/inspiration.py | 192 ++++++++++++++++++++++++++++++---------------- modules/shared.py | 6 ++ modules/ui.py | 2 +- webui.py | 3 +- 5 files changed, 151 insertions(+), 79 deletions(-) (limited to 'modules') diff --git a/javascript/inspiration.js b/javascript/inspiration.js index e1c0e114..791a80c9 100644 --- a/javascript/inspiration.js +++ b/javascript/inspiration.js @@ -1,25 +1,31 @@ function public_image_index_in_gallery(item, gallery){ + var imgs = gallery.querySelectorAll("img.h-full") var index; var i = 0; - gallery.querySelectorAll("img").forEach(function(e){ + imgs.forEach(function(e){ if (e == item) index = i; i += 1; }); + var num = imgs.length / 2 + index = (index < num) ? index : (index - num) return index; } -function inspiration_selected(name, types, name_list){ +function inspiration_selected(name, name_list){ var btn = gradioApp().getElementById("inspiration_select_button") - return [gradioApp().getElementById("inspiration_select_button").getAttribute("img-index"), types]; -} + return [gradioApp().getElementById("inspiration_select_button").getAttribute("img-index")]; +} +function inspiration_click_get_button(){ + gradioApp().getElementById("inspiration_get_button").click(); +} var inspiration_image_click = function(){ var index = public_image_index_in_gallery(this, gradioApp().getElementById("inspiration_gallery")); - var btn = gradioApp().getElementById("inspiration_select_button") - btn.setAttribute("img-index", index) - setTimeout(function(btn){btn.click();}, 10, btn) + var btn = gradioApp().getElementById("inspiration_select_button"); + btn.setAttribute("img-index", index); + setTimeout(function(btn){btn.click();}, 10, btn); } - + document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ var gallery = gradioApp().getElementById("inspiration_gallery") @@ -27,11 +33,10 @@ document.addEventListener("DOMContentLoaded", function() { var node = gallery.querySelector(".absolute.backdrop-blur.h-full") if (node) { node.style.display = "None"; //parentNode.removeChild(node) - } - + } gallery.querySelectorAll('img').forEach(function(e){ e.onclick = inspiration_image_click - }) + }); } diff --git a/modules/inspiration.py b/modules/inspiration.py index 456bfcb5..f72ebf3a 100644 --- a/modules/inspiration.py +++ b/modules/inspiration.py @@ -1,122 +1,182 @@ import os import random -import gradio -inspiration_path = "inspiration" -inspiration_system_path = os.path.join(inspiration_path, "system") -def read_name_list(file): +import gradio +from modules.shared import opts +inspiration_system_path = os.path.join(opts.inspiration_dir, "system") +def read_name_list(file, types=None, keyword=None): if not os.path.exists(file): return [] - f = open(file, "r") ret = [] + f = open(file, "r") line = f.readline() while len(line) > 0: line = line.rstrip("\n") - ret.append(line) - print(ret) + if types is not None: + dirname = os.path.split(line) + if dirname[0] in types and keyword in dirname[1]: + ret.append(line) + else: + ret.append(line) + line = f.readline() return ret def save_name_list(file, name): - print(file) - f = open(file, "a") - f.write(name + "\n") + with open(file, "a") as f: + f.write(name + "\n") -def get_inspiration_images(source, types): - path = os.path.join(inspiration_path , types) +def get_types_list(): + files = os.listdir(opts.inspiration_dir) + types = [] + for x in files: + path = os.path.join(opts.inspiration_dir, x) + if x[0] == ".": + continue + if not os.path.isdir(path): + continue + if path == inspiration_system_path: + continue + types.append(x) + return types + +def get_inspiration_images(source, types, keyword): + get_num = int(opts.inspiration_rows_num * opts.inspiration_cols_num) if source == "Favorites": - names = read_name_list(os.path.join(inspiration_system_path, types + "_faverites.txt")) - names = random.sample(names, 25) + names = read_name_list(os.path.join(inspiration_system_path, "faverites.txt"), types, keyword) + names = random.sample(names, get_num) if len(names) > get_num else names elif source == "Abandoned": - names = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) - names = random.sample(names, 25) - elif source == "Exclude abandoned": - abondened = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) - all_names = os.listdir(path) - names = [] - while len(names) < 25: - name = random.choice(all_names) - if name not in abondened: - names.append(name) + names = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) + print(names) + names = random.sample(names, get_num) if len(names) > get_num else names + elif source == "Exclude abandoned": + abandoned = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) + all_names = [] + for tp in types: + name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) + all_names += [os.path.join(tp, x) for x in name_list if keyword in x] + + if len(all_names) > get_num: + names = [] + while len(names) < get_num: + name = random.choice(all_names) + if name not in abandoned: + names.append(name) + else: + names = all_names else: - names = random.sample(os.listdir(path), 25) - names = random.sample(names, 25) + all_names = [] + for tp in types: + name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) + all_names += [os.path.join(tp, x) for x in name_list if keyword in x] + names = random.sample(all_names, get_num) if len(all_names) > get_num else all_names image_list = [] for a in names: - image_path = os.path.join(path, a) + image_path = os.path.join(opts.inspiration_dir, a) images = os.listdir(image_path) - image_list.append(os.path.join(image_path, random.choice(images))) - return image_list, names + image_list.append((os.path.join(image_path, random.choice(images)), a)) + return image_list, names, "" -def select_click(index, types, name_list): +def select_click(index, name_list): name = name_list[int(index)] - path = os.path.join(inspiration_path, types, name) + path = os.path.join(opts.inspiration_dir, name) images = os.listdir(path) - return name, [os.path.join(path, x) for x in images] + return name, [os.path.join(path, x) for x in images], "" -def give_up_click(name, types): - file = os.path.join(inspiration_system_path, types + "_abandoned.txt") +def give_up_click(name): + file = os.path.join(inspiration_system_path, "abandoned.txt") name_list = read_name_list(file) if name not in name_list: save_name_list(file, name) + return "Added to abandoned list" -def collect_click(name, types): - file = os.path.join(inspiration_system_path, types + "_faverites.txt") - print(file) +def collect_click(name): + file = os.path.join(inspiration_system_path, "faverites.txt") name_list = read_name_list(file) - print(name_list) if name not in name_list: save_name_list(file, name) + return "Added to faverite list" -def moveout_click(name, types): - file = os.path.join(inspiration_system_path, types + "_faverites.txt") +def moveout_click(name, source): + if source == "Abandoned": + file = os.path.join(inspiration_system_path, "abandoned.txt") + if source == "Favorites": + file = os.path.join(inspiration_system_path, "faverites.txt") + else: + return None name_list = read_name_list(file) - if name not in name_list: - save_name_list(file, name) + os.remove(file) + with open(file, "a") as f: + for a in name_list: + if a != name: + f.write(a) + return "Moved out {name} from {source} list" def source_change(source): - if source == "Abandoned" or source == "Favorites": - return gradio.Button.update(visible=True, value=f"Move out {source}") + if source in ["Abandoned", "Favorites"]: + return gradio.update(visible=True), [] else: - return gradio.Button.update(visible=False) + return gradio.update(visible=False), [] +def add_to_prompt(name, prompt): + print(name, prompt) + name = os.path.basename(name) + return prompt + "," + name -def ui(gr, opts): +def ui(gr, opts, txt2img_prompt, img2img_prompt): with gr.Blocks(analytics_enabled=False) as inspiration: - flag = os.path.exists(inspiration_path) + flag = os.path.exists(opts.inspiration_dir) if flag: - types = os.listdir(inspiration_path) - types = [x for x in types if x != "system"] + types = get_types_list() flag = len(types) > 0 - if not flag: - os.mkdir(inspiration_path) + else: + os.makedirs(opts.inspiration_dir) + if not flag: gr.HTML(""" -
" +

To activate inspiration function, you need get "inspiration" images first.


+ You can create these images by run "Create inspiration images" script in txt2img page,
you can get the artists or art styles list from here
+ https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
+ download these files, and select these files in the "Create inspiration images" script UI
+ There about 6000 artists and art styles in these files.
This takes server hours depending on your GPU type and how many pictures you generate for each artist/style +
I suggest at least four images for each


+

You can also download generated pictures from here:


+ https://huggingface.co/datasets/yfszzx/inspiration
+ unzip the file to the project directory of webui
+ and restart webui, and enjoy the joy of creation!
""") return inspiration if not os.path.exists(inspiration_system_path): os.mkdir(inspiration_system_path) - gallery, names = get_inspiration_images("Exclude abandoned", types[0]) with gr.Row(): with gr.Column(scale=2): - inspiration_gallery = gr.Gallery(gallery, show_label=False, elem_id="inspiration_gallery").style(grid=5, height='auto') + inspiration_gallery = gr.Gallery(show_label=False, elem_id="inspiration_gallery").style(grid=opts.inspiration_cols_num, height='auto') with gr.Column(scale=1): - types = gr.Dropdown(choices=types, value=types[0], label="Type", visible=len(types) > 1) + print(types) + types = gr.CheckboxGroup(choices=types, value=types) + keyword = gr.Textbox("", label="Key word") with gr.Row(): source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") - get_inspiration = gr.Button("Get inspiration") + get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button") name = gr.Textbox(show_label=False, interactive=False) with gr.Row(): send_to_txt2img = gr.Button('to txt2img') send_to_img2img = gr.Button('to img2img') - style_gallery = gr.Gallery(show_label=False, elem_id="inspiration_style_gallery").style(grid=2, height='auto') - + style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') collect = gr.Button('Collect') - give_up = gr.Button("Don't show any more") + give_up = gr.Button("Don't show again") moveout = gr.Button("Move out", visible=False) - with gr.Row(): + warning = gr.HTML() + with gr.Row(visible=False): select_button = gr.Button('set button', elem_id="inspiration_select_button") - name_list = gr.State(names) - source.change(source_change, inputs=[source], outputs=[moveout]) - get_inspiration.click(get_inspiration_images, inputs=[source, types], outputs=[inspiration_gallery, name_list]) - select_button.click(select_click, _js="inspiration_selected", inputs=[name, types, name_list], outputs=[name, style_gallery]) - give_up.click(give_up_click, inputs=[name, types], outputs=None) - collect.click(collect_click, inputs=[name, types], outputs=None) + name_list = gr.State() + + get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list, keyword]) + source.change(source_change, inputs=[source], outputs=[moveout, style_gallery]) + source.change(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + keyword.submit(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + select_button.click(select_click, _js="inspiration_selected", inputs=[name, name_list], outputs=[name, style_gallery, warning]) + give_up.click(give_up_click, inputs=[name], outputs=[warning]) + collect.click(collect_click, inputs=[name], outputs=[warning]) + moveout.click(moveout_click, inputs=[name, source], outputs=[warning]) + send_to_txt2img.click(add_to_prompt, inputs=[name, txt2img_prompt], outputs=[txt2img_prompt]) + send_to_img2img.click(add_to_prompt, inputs=[name, img2img_prompt], outputs=[img2img_prompt]) + send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None) + send_to_img2img.click(None, _js="switch_to_img2img_img2img", inputs=None, outputs=None) return inspiration diff --git a/modules/shared.py b/modules/shared.py index ae033710..564b1b8d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -316,6 +316,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) +options_templates.update(options_section(('inspiration', "Inspiration"), { + "inspiration_dir": OptionInfo("inspiration", "Directory of inspiration", component_args=hide_dirs), + "inspiration_max_samples": OptionInfo(4, "Maximum number of samples, used to determine which folders to skip when continue running the create script", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), + "inspiration_rows_num": OptionInfo(4, "Rows of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), + "inspiration_cols_num": OptionInfo(8, "Columns of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), +})) class Options: data = None diff --git a/modules/ui.py b/modules/ui.py index 6a0a3c3b..b651eb9c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1180,7 +1180,7 @@ def create_ui(wrap_gradio_gpu_call): } browser_interface = images_history.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) - inspiration_interface = inspiration.ui(gr, opts) + inspiration_interface = inspiration.ui(gr, opts, txt2img_prompt, img2img_prompt) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): diff --git a/webui.py b/webui.py index 5923905f..5ccae715 100644 --- a/webui.py +++ b/webui.py @@ -72,6 +72,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) def initialize(): + modules.scripts.load_scripts(os.path.join(script_path, "scripts")) if cmd_opts.ui_debug_mode: class enmpty(): name = None @@ -84,7 +85,7 @@ def initialize(): shared.face_restorers.append(modules.face_restoration.FaceRestoration()) modelloader.load_upscalers() - modules.scripts.load_scripts(os.path.join(script_path, "scripts")) + shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) -- cgit v1.2.3 From 9ba372de90df81c4f1e992d8b33ae17c6630de95 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 13:55:42 -0500 Subject: initial work on getting prompts cleared on the backend and synchronizing token counter --- javascript/ui.js | 10 +++++++--- modules/ui.py | 29 +++++++++++++++++++---------- 2 files changed, 26 insertions(+), 13 deletions(-) (limited to 'modules') diff --git a/javascript/ui.js b/javascript/ui.js index f19af550..a0f01d10 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -162,9 +162,13 @@ function selected_tab_id() { } -function trash_prompt(_,_, is_img2img) { +function trash_prompt(_, confirmed) { -if(!confirm("Delete prompt?")) return false +if(confirm("Delete prompt?")) { + confirmed = true +} else { +return [_, confirmed] +} if(selected_tab_id() == "tab_txt2img") { gradioApp().querySelector("#txt2img_prompt > label > textarea").value = ""; @@ -178,7 +182,7 @@ if(!confirm("Delete prompt?")) return false update_token_counter("txt2img_token_button") } - return true + return [_, confirmed] } diff --git a/modules/ui.py b/modules/ui.py index d2cb528e..2748a2e0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,15 +429,16 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -# setup button for clearing prompt input boxes on client side of webui -def connect_trash_prompt(dummy_component, button, is_img2img): +def clear_prompt(prompt): + print(f"type: '{prompt}'") + print(prompt) + + # update_token_counter(prompt, steps) + return "" + +def connect_trash_prompt(prompt, confirmed): + if(confirmed): clear_prompt(prompt) - button.click( - fn=lambda: print("Clearing prompt"), - _js="trash_prompt", - inputs=[], - outputs=[], - ) def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): """ Connects a 'reuse (sub)seed' button's click event so that it copies last used @@ -640,7 +641,16 @@ def create_ui(wrap_gradio_gpu_call): dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - connect_trash_prompt(dummy_component, trash_prompt_button, False) + + + trash_prompt_button.click( + # fn=lambda: print("Clearing prompt"), + _js="trash_prompt", + fn=lambda: clear_prompt(), + inputs=[txt2img_prompt, dummy_component], + outputs=[txt2img_prompt, dummy_component], + ) + with gr.Row(elem_id='txt2img_progress_row'): with gr.Column(scale=1): @@ -848,7 +858,6 @@ def create_ui(wrap_gradio_gpu_call): img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) - connect_trash_prompt(dummy_component,trash_prompt_button, True) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) -- cgit v1.2.3 From ee0505dd0092ae7073b77aba93a858bda000dc60 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 14:24:14 -0500 Subject: only delete prompt on back end and remove client-side deletion --- javascript/ui.js | 6 ------ modules/ui.py | 29 +++++++++++++++-------------- 2 files changed, 15 insertions(+), 20 deletions(-) (limited to 'modules') diff --git a/javascript/ui.js b/javascript/ui.js index a0f01d10..29306abe 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -171,14 +171,8 @@ return [_, confirmed] } if(selected_tab_id() == "tab_txt2img") { - gradioApp().querySelector("#txt2img_prompt > label > textarea").value = ""; - gradioApp().querySelector("#txt2img_neg_prompt > label > textarea").value = ""; - update_token_counter("img2img_token_button") } else { - gradioApp().querySelector("#img2img_prompt > label > textarea").value = ""; - gradioApp().querySelector("#img2img_neg_prompt > label > textarea").value = ""; - update_token_counter("txt2img_token_button") } diff --git a/modules/ui.py b/modules/ui.py index 2748a2e0..90c338da 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,15 +429,21 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -def clear_prompt(prompt): - print(f"type: '{prompt}'") - print(prompt) - # update_token_counter(prompt, steps) - return "" +def connect_trash_prompt(_, confirmed): + if(confirmed): + # update_token_counter(prompt, steps) + return ["", confirmed] -def connect_trash_prompt(prompt, confirmed): - if(confirmed): clear_prompt(prompt) +def trash_prompt_click(button, prompt): + dummy_component = gradio.Label(visible=False) + + button.click( + _js="trash_prompt", + fn=connect_trash_prompt, + inputs=[prompt, dummy_component], + outputs=[prompt, dummy_component], + ) def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): @@ -643,13 +649,7 @@ def create_ui(wrap_gradio_gpu_call): txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - trash_prompt_button.click( - # fn=lambda: print("Clearing prompt"), - _js="trash_prompt", - fn=lambda: clear_prompt(), - inputs=[txt2img_prompt, dummy_component], - outputs=[txt2img_prompt, dummy_component], - ) + trash_prompt_click(trash_prompt_button, txt2img_prompt) with gr.Row(elem_id='txt2img_progress_row'): @@ -858,6 +858,7 @@ def create_ui(wrap_gradio_gpu_call): img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) + trash_prompt_click(trash_prompt_button, img2img_prompt) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) -- cgit v1.2.3 From de70ddaf58fae98c561738a54f574abfa14cd8d1 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 15:00:35 -0500 Subject: update token counter when clearing prompt --- javascript/ui.js | 4 ++-- modules/ui.py | 17 +++++++---------- 2 files changed, 9 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/javascript/ui.js b/javascript/ui.js index 29306abe..acd57565 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -162,7 +162,7 @@ function selected_tab_id() { } -function trash_prompt(_, confirmed) { +function trash_prompt(_, confirmed,_steps) { if(confirm("Delete prompt?")) { confirmed = true @@ -176,7 +176,7 @@ return [_, confirmed] update_token_counter("txt2img_token_button") } - return [_, confirmed] + return [_, confirmed,_steps] } diff --git a/modules/ui.py b/modules/ui.py index 90c338da..d3a89bf7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -430,19 +430,16 @@ def create_seed_inputs(): -def connect_trash_prompt(_, confirmed): +def connect_trash_prompt(_prompt, confirmed, _token_counter): if(confirmed): - # update_token_counter(prompt, steps) - return ["", confirmed] - -def trash_prompt_click(button, prompt): - dummy_component = gradio.Label(visible=False) + return ["", confirmed, update_token_counter("", 1)] +def trash_prompt_click(button, prompt, _dummy_confirmed, token_counter): button.click( _js="trash_prompt", fn=connect_trash_prompt, - inputs=[prompt, dummy_component], - outputs=[prompt, dummy_component], + inputs=[prompt, _dummy_confirmed, token_counter], + outputs=[prompt, _dummy_confirmed, token_counter], ) @@ -649,7 +646,6 @@ def create_ui(wrap_gradio_gpu_call): txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - trash_prompt_click(trash_prompt_button, txt2img_prompt) with gr.Row(elem_id='txt2img_progress_row'): @@ -720,6 +716,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + trash_prompt_click(trash_prompt_button, txt2img_prompt, dummy_component, token_counter) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), @@ -858,7 +855,6 @@ def create_ui(wrap_gradio_gpu_call): img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) - trash_prompt_click(trash_prompt_button, img2img_prompt) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -958,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + trash_prompt_click(trash_prompt_button, img2img_prompt, dummy_component, token_counter) img2img_prompt_img.change( fn=modules.images.image_data, -- cgit v1.2.3 From 9e40520f00d836cfa93187f7f1e81e2a7bd100b9 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 15:13:12 -0500 Subject: refactor internal terminology to use 'clear' instead of 'trash' like #2728 --- javascript/ui.js | 2 +- modules/shared.py | 2 +- modules/ui.py | 22 +++++++++++----------- 3 files changed, 13 insertions(+), 13 deletions(-) (limited to 'modules') diff --git a/javascript/ui.js b/javascript/ui.js index acd57565..45d93a5c 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -162,7 +162,7 @@ function selected_tab_id() { } -function trash_prompt(_, confirmed,_steps) { +function clear_prompt(_, confirmed,_steps) { if(confirm("Delete prompt?")) { confirmed = true diff --git a/modules/shared.py b/modules/shared.py index 1585d532..ab5a0e9a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -317,7 +317,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), - "trash_prompt_visible": OptionInfo(True, "Show trash prompt button"), + "clear_prompt_visible": OptionInfo(True, "Show clear prompt button"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui.py b/modules/ui.py index d3a89bf7..31150800 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -88,7 +88,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 -trash_prompt_symbol = '\U0001F5D1' # +clear_prompt_symbol = '\U0001F5D1' # 🗑️ def plaintext_to_html(text): @@ -430,14 +430,14 @@ def create_seed_inputs(): -def connect_trash_prompt(_prompt, confirmed, _token_counter): +def clear_prompt(_prompt, confirmed, _token_counter): if(confirmed): return ["", confirmed, update_token_counter("", 1)] -def trash_prompt_click(button, prompt, _dummy_confirmed, token_counter): +def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter): button.click( - _js="trash_prompt", - fn=connect_trash_prompt, + _js="clear_prompt", + fn=clear_prompt, inputs=[prompt, _dummy_confirmed, token_counter], outputs=[prompt, _dummy_confirmed, token_counter], ) @@ -518,7 +518,7 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt", visible=opts.trash_prompt_visible) + clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id="clear_prompt", visible=opts.clear_prompt_visible) token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") @@ -559,7 +559,7 @@ def create_toprow(is_img2img): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) prompt_style2.save_to_config = True - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, trash_prompt + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, clear_prompt_button def setup_progressbar(progressbar, preview, id_part, textinfo=None): @@ -640,7 +640,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\ txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\ - token_button, trash_prompt_button = create_toprow(is_img2img=False) + token_button, clear_prompt_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -716,7 +716,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - trash_prompt_click(trash_prompt_button, txt2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, txt2img_prompt, dummy_component, token_counter) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), @@ -853,7 +853,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit,\ img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ - token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) + token_counter, token_button, clear_prompt_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): @@ -954,7 +954,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - trash_prompt_click(trash_prompt_button, img2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, img2img_prompt, dummy_component, token_counter) img2img_prompt_img.change( fn=modules.images.image_data, -- cgit v1.2.3 From 0c7cf08b3d27a61bab4cd8b16f8be8ae74879423 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 15:32:26 -0500 Subject: some doc and formatting --- modules/ui.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 31150800..b26cf10a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -88,7 +88,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 -clear_prompt_symbol = '\U0001F5D1' # 🗑️ +clear_prompt_symbol = '\U0001F5D1' # 🗑️ def plaintext_to_html(text): @@ -429,12 +429,14 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox - def clear_prompt(_prompt, confirmed, _token_counter): - if(confirmed): - return ["", confirmed, update_token_counter("", 1)] + """Given confirmation from a user on the client-side, go ahead with clearing prompt""" + if confirmed: + return ["", confirmed, update_token_counter("", 1)] + def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter): + """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" button.click( _js="clear_prompt", fn=clear_prompt, @@ -518,7 +520,12 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id="clear_prompt", visible=opts.clear_prompt_visible) + + clear_prompt_button = gr.Button( + value=clear_prompt_symbol, + elem_id="clear_prompt", + visible=opts.clear_prompt_visible + ) token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") -- cgit v1.2.3 From 57eb54b838faa383c10079e1bb5471b7bee6a695 Mon Sep 17 00:00:00 2001 From: Extraltodeus Date: Sat, 22 Oct 2022 00:11:07 +0200 Subject: implement CUDA device selection by ID --- modules/devices.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index eb422583..8a159282 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,7 +1,6 @@ +import sys, os, shlex import contextlib - import torch - from modules import errors # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility @@ -9,10 +8,26 @@ has_mps = getattr(torch, 'has_mps', False) cpu = torch.device("cpu") +def extract_device_id(args, name): + for x in range(len(args)): + if name in args[x]: return args[x+1] + return None def get_optimal_device(): if torch.cuda.is_available(): - return torch.device("cuda") + # CUDA device selection support: + if "shared" not in sys.modules: + commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop. + sys.argv += shlex.split(commandline_args) + device_id = extract_device_id(sys.argv, '--device-id') + else: + device_id = shared.cmd_opts.device_id + + if device_id is not None: + cuda_device = f"cuda:{device_id}" + return torch.device(cuda_device) + else: + return torch.device("cuda") if has_mps: return torch.device("mps") -- cgit v1.2.3 From 29bfacd63cb5c73b9643d94f255cca818fd49d9c Mon Sep 17 00:00:00 2001 From: Extraltodeus Date: Sat, 22 Oct 2022 00:12:46 +0200 Subject: implement CUDA device selection, --device-id arg --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 41d7f08e..03032a47 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -80,6 +80,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) cmd_opts = parser.parse_args() restricted_opts = [ -- cgit v1.2.3 From 700340448baa7412c7cc5ff3d1349ac79ee8ed0c Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 17:24:04 -0500 Subject: forgot to clear neg prompt after moving to back. Add tooltip to hints --- javascript/hints.js | 1 + javascript/ui.js | 4 ++-- modules/ui.py | 14 +++++++------- 3 files changed, 10 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/javascript/hints.js b/javascript/hints.js index a1fcc93b..54c8c238 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -17,6 +17,7 @@ titles = { "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u{1f4c2}": "Open images output directory", "\u{1f4be}": "Save style", + "\U0001F5D1": "Clear prompt" "\u{1f4cb}": "Apply selected styles to current prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", diff --git a/javascript/ui.js b/javascript/ui.js index 45d93a5c..6c99824b 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -162,7 +162,7 @@ function selected_tab_id() { } -function clear_prompt(_, confirmed,_steps) { +function clear_prompt(_, _prompt_neg, confirmed,_steps) { if(confirm("Delete prompt?")) { confirmed = true @@ -176,7 +176,7 @@ return [_, confirmed] update_token_counter("txt2img_token_button") } - return [_, confirmed,_steps] + return [_, _prompt_neg, confirmed,_steps] } diff --git a/modules/ui.py b/modules/ui.py index b26cf10a..25aeba3b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,19 +429,19 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -def clear_prompt(_prompt, confirmed, _token_counter): +def clear_prompt(_prompt, _prompt_neg, confirmed, _token_counter): """Given confirmation from a user on the client-side, go ahead with clearing prompt""" if confirmed: - return ["", confirmed, update_token_counter("", 1)] + return ["", "", confirmed, update_token_counter("", 1)] -def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter): +def connect_clear_prompt(button, prompt, prompt_neg, _dummy_confirmed, token_counter): """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" button.click( _js="clear_prompt", fn=clear_prompt, - inputs=[prompt, _dummy_confirmed, token_counter], - outputs=[prompt, _dummy_confirmed, token_counter], + inputs=[prompt, prompt_neg, _dummy_confirmed, token_counter], + outputs=[prompt, prompt_neg, _dummy_confirmed, token_counter], ) @@ -723,7 +723,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - connect_clear_prompt(clear_prompt_button, txt2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, txt2img_prompt, txt2img_negative_prompt, dummy_component, token_counter) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), @@ -961,7 +961,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - connect_clear_prompt(clear_prompt_button, img2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, img2img_prompt, img2img_negative_prompt, dummy_component, token_counter) img2img_prompt_img.change( fn=modules.images.image_data, -- cgit v1.2.3 From 40ddb6df61564684263c7442bacf61efe3882b87 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sat, 22 Oct 2022 10:16:22 +0800 Subject: inspiration perfected --- javascript/inspiration.js | 19 +++++++------ modules/inspiration.py | 71 ++++++++++++++++++++++++++--------------------- 2 files changed, 49 insertions(+), 41 deletions(-) (limited to 'modules') diff --git a/javascript/inspiration.js b/javascript/inspiration.js index 791a80c9..39844544 100644 --- a/javascript/inspiration.js +++ b/javascript/inspiration.js @@ -1,5 +1,5 @@ function public_image_index_in_gallery(item, gallery){ - var imgs = gallery.querySelectorAll("img.h-full") + var imgs = gallery.querySelectorAll("img.h-full") var index; var i = 0; imgs.forEach(function(e){ @@ -7,18 +7,23 @@ function public_image_index_in_gallery(item, gallery){ index = i; i += 1; }); - var num = imgs.length / 2 - index = (index < num) ? index : (index - num) + var all_imgs = gallery.querySelectorAll("img") + if (all_imgs.length > imgs.length){ + var num = imgs.length / 2 + index = (index < num) ? index : (index - num) + } return index; } function inspiration_selected(name, name_list){ var btn = gradioApp().getElementById("inspiration_select_button") return [gradioApp().getElementById("inspiration_select_button").getAttribute("img-index")]; -} +} + function inspiration_click_get_button(){ gradioApp().getElementById("inspiration_get_button").click(); } + var inspiration_image_click = function(){ var index = public_image_index_in_gallery(this, gradioApp().getElementById("inspiration_gallery")); var btn = gradioApp().getElementById("inspiration_select_button"); @@ -32,16 +37,12 @@ document.addEventListener("DOMContentLoaded", function() { if (gallery) { var node = gallery.querySelector(".absolute.backdrop-blur.h-full") if (node) { - node.style.display = "None"; //parentNode.removeChild(node) + node.style.display = "None"; } gallery.querySelectorAll('img').forEach(function(e){ e.onclick = inspiration_image_click }); - } - - }); mutationObserver.observe( gradioApp(), { childList:true, subtree:true }); - }); diff --git a/modules/inspiration.py b/modules/inspiration.py index f72ebf3a..319183ab 100644 --- a/modules/inspiration.py +++ b/modules/inspiration.py @@ -13,7 +13,7 @@ def read_name_list(file, types=None, keyword=None): line = line.rstrip("\n") if types is not None: dirname = os.path.split(line) - if dirname[0] in types and keyword in dirname[1]: + if dirname[0] in types and keyword in dirname[1].lower(): ret.append(line) else: ret.append(line) @@ -21,8 +21,10 @@ def read_name_list(file, types=None, keyword=None): return ret def save_name_list(file, name): - with open(file, "a") as f: - f.write(name + "\n") + name_list = read_name_list(file) + if name not in name_list: + with open(file, "a") as f: + f.write(name + "\n") def get_types_list(): files = os.listdir(opts.inspiration_dir) @@ -39,20 +41,20 @@ def get_types_list(): return types def get_inspiration_images(source, types, keyword): + keyword = keyword.strip(" ").lower() get_num = int(opts.inspiration_rows_num * opts.inspiration_cols_num) if source == "Favorites": names = read_name_list(os.path.join(inspiration_system_path, "faverites.txt"), types, keyword) names = random.sample(names, get_num) if len(names) > get_num else names elif source == "Abandoned": names = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) - print(names) names = random.sample(names, get_num) if len(names) > get_num else names elif source == "Exclude abandoned": abandoned = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) all_names = [] for tp in types: name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) - all_names += [os.path.join(tp, x) for x in name_list if keyword in x] + all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()] if len(all_names) > get_num: names = [] @@ -66,14 +68,14 @@ def get_inspiration_images(source, types, keyword): all_names = [] for tp in types: name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) - all_names += [os.path.join(tp, x) for x in name_list if keyword in x] + all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()] names = random.sample(all_names, get_num) if len(all_names) > get_num else all_names image_list = [] for a in names: image_path = os.path.join(opts.inspiration_dir, a) images = os.listdir(image_path) image_list.append((os.path.join(image_path, random.choice(images)), a)) - return image_list, names, "" + return image_list, names def select_click(index, name_list): name = name_list[int(index)] @@ -83,22 +85,18 @@ def select_click(index, name_list): def give_up_click(name): file = os.path.join(inspiration_system_path, "abandoned.txt") - name_list = read_name_list(file) - if name not in name_list: - save_name_list(file, name) + save_name_list(file, name) return "Added to abandoned list" def collect_click(name): file = os.path.join(inspiration_system_path, "faverites.txt") - name_list = read_name_list(file) - if name not in name_list: - save_name_list(file, name) + save_name_list(file, name) return "Added to faverite list" def moveout_click(name, source): if source == "Abandoned": file = os.path.join(inspiration_system_path, "abandoned.txt") - if source == "Favorites": + elif source == "Favorites": file = os.path.join(inspiration_system_path, "faverites.txt") else: return None @@ -107,8 +105,8 @@ def moveout_click(name, source): with open(file, "a") as f: for a in name_list: if a != name: - f.write(a) - return "Moved out {name} from {source} list" + f.write(a + "\n") + return f"Moved out {name} from {source} list" def source_change(source): if source in ["Abandoned", "Favorites"]: @@ -116,10 +114,12 @@ def source_change(source): else: return gradio.update(visible=False), [] def add_to_prompt(name, prompt): - print(name, prompt) name = os.path.basename(name) return prompt + "," + name +def clear_keyword(): + return "" + def ui(gr, opts, txt2img_prompt, img2img_prompt): with gr.Blocks(analytics_enabled=False) as inspiration: flag = os.path.exists(opts.inspiration_dir) @@ -132,15 +132,15 @@ def ui(gr, opts, txt2img_prompt, img2img_prompt): gr.HTML("""

To activate inspiration function, you need get "inspiration" images first.


You can create these images by run "Create inspiration images" script in txt2img page,
you can get the artists or art styles list from here
- https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
+ https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
download these files, and select these files in the "Create inspiration images" script UI
There about 6000 artists and art styles in these files.
This takes server hours depending on your GPU type and how many pictures you generate for each artist/style
I suggest at least four images for each


You can also download generated pictures from here:


- https://huggingface.co/datasets/yfszzx/inspiration
+ https://huggingface.co/datasets/yfszzx/inspiration
unzip the file to the project directory of webui
and restart webui, and enjoy the joy of creation!
- """) + """) return inspiration if not os.path.exists(inspiration_system_path): os.mkdir(inspiration_system_path) @@ -148,35 +148,42 @@ def ui(gr, opts, txt2img_prompt, img2img_prompt): with gr.Column(scale=2): inspiration_gallery = gr.Gallery(show_label=False, elem_id="inspiration_gallery").style(grid=opts.inspiration_cols_num, height='auto') with gr.Column(scale=1): - print(types) types = gr.CheckboxGroup(choices=types, value=types) - keyword = gr.Textbox("", label="Key word") - with gr.Row(): - source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") - get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button") - name = gr.Textbox(show_label=False, interactive=False) with gr.Row(): + source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") + keyword = gr.Textbox("", label="Key word") + get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button") + name = gr.Textbox(show_label=False, interactive=False) + with gr.Row(): send_to_txt2img = gr.Button('to txt2img') send_to_img2img = gr.Button('to img2img') style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') - collect = gr.Button('Collect') - give_up = gr.Button("Don't show again") - moveout = gr.Button("Move out", visible=False) warning = gr.HTML() + with gr.Row(): + collect = gr.Button('Collect') + give_up = gr.Button("Don't show again") + moveout = gr.Button("Move out", visible=False) + with gr.Row(visible=False): select_button = gr.Button('set button', elem_id="inspiration_select_button") name_list = gr.State() - get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list, keyword]) - source.change(source_change, inputs=[source], outputs=[moveout, style_gallery]) - source.change(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list]) keyword.submit(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + source.change(source_change, inputs=[source], outputs=[moveout, style_gallery]) + source.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword]) + types.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword]) + select_button.click(select_click, _js="inspiration_selected", inputs=[name, name_list], outputs=[name, style_gallery, warning]) give_up.click(give_up_click, inputs=[name], outputs=[warning]) collect.click(collect_click, inputs=[name], outputs=[warning]) moveout.click(moveout_click, inputs=[name, source], outputs=[warning]) + moveout.click(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + send_to_txt2img.click(add_to_prompt, inputs=[name, txt2img_prompt], outputs=[txt2img_prompt]) send_to_img2img.click(add_to_prompt, inputs=[name, img2img_prompt], outputs=[img2img_prompt]) + send_to_txt2img.click(collect_click, inputs=[name], outputs=[warning]) + send_to_img2img.click(collect_click, inputs=[name], outputs=[warning]) send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None) send_to_img2img.click(None, _js="switch_to_img2img_img2img", inputs=None, outputs=None) return inspiration -- cgit v1.2.3 From d93ea5cdeb2fd3607b7265271ccab2c9bf4c1156 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sat, 22 Oct 2022 10:21:21 +0800 Subject: inspiration perfected --- modules/inspiration.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/inspiration.py b/modules/inspiration.py index 319183ab..94ff139a 100644 --- a/modules/inspiration.py +++ b/modules/inspiration.py @@ -73,8 +73,11 @@ def get_inspiration_images(source, types, keyword): image_list = [] for a in names: image_path = os.path.join(opts.inspiration_dir, a) - images = os.listdir(image_path) - image_list.append((os.path.join(image_path, random.choice(images)), a)) + images = os.listdir(image_path) + if len(images) > 0: + image_list.append((os.path.join(image_path, random.choice(images)), a)) + else: + print(image_path) return image_list, names def select_click(index, name_list): -- cgit v1.2.3 From 67b78f0ea6f196bfdca49932da062631bb40d0b1 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sat, 22 Oct 2022 10:29:23 +0800 Subject: inspiration perfected --- modules/inspiration.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/inspiration.py b/modules/inspiration.py index 94ff139a..29cf8297 100644 --- a/modules/inspiration.py +++ b/modules/inspiration.py @@ -160,12 +160,13 @@ def ui(gr, opts, txt2img_prompt, img2img_prompt): with gr.Row(): send_to_txt2img = gr.Button('to txt2img') send_to_img2img = gr.Button('to img2img') - style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') - warning = gr.HTML() - with gr.Row(): collect = gr.Button('Collect') give_up = gr.Button("Don't show again") - moveout = gr.Button("Move out", visible=False) + moveout = gr.Button("Move out", visible=False) + warning = gr.HTML() + style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') + + with gr.Row(visible=False): select_button = gr.Button('set button', elem_id="inspiration_select_button") -- cgit v1.2.3 From 2b91251637078e04472c91a06a8d9c4db9c1dcf0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 12:23:45 +0300 Subject: removed aesthetic gradients as built-in added support for extensions --- .gitignore | 2 +- extensions/put extension here.txt | 0 modules/aesthetic_clip.py | 241 -------------------------------------- modules/images_history.py | 2 +- modules/img2img.py | 5 +- modules/processing.py | 35 ++++-- modules/script_callbacks.py | 42 +++++++ modules/scripts.py | 210 ++++++++++++++++++++++++--------- modules/sd_hijack.py | 1 - modules/sd_models.py | 7 +- modules/shared.py | 19 --- modules/txt2img.py | 5 +- modules/ui.py | 83 ++----------- webui.py | 7 +- 14 files changed, 249 insertions(+), 410 deletions(-) create mode 100644 extensions/put extension here.txt delete mode 100644 modules/aesthetic_clip.py create mode 100644 modules/script_callbacks.py (limited to 'modules') diff --git a/.gitignore b/.gitignore index f9c3357c..2f1e08ed 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,4 @@ __pycache__ notification.mp3 /SwinIR /textual_inversion -.vscode \ No newline at end of file +.vscode diff --git a/extensions/put extension here.txt b/extensions/put extension here.txt new file mode 100644 index 00000000..e69de29b diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py deleted file mode 100644 index 8c828541..00000000 --- a/modules/aesthetic_clip.py +++ /dev/null @@ -1,241 +0,0 @@ -import copy -import itertools -import os -from pathlib import Path -import html -import gc - -import gradio as gr -import torch -from PIL import Image -from torch import optim - -from modules import shared -from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer -from tqdm.auto import tqdm, trange -from modules.shared import opts, device - - -def get_all_images_in_folder(folder): - return [os.path.join(folder, f) for f in os.listdir(folder) if - os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)] - - -def check_is_valid_image_file(filename): - return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp")) - - -def batched(dataset, total, n=1): - for ndx in range(0, total, n): - yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] - - -def iter_to_batched(iterable, n=1): - it = iter(iterable) - while True: - chunk = tuple(itertools.islice(it, n)) - if not chunk: - return - yield chunk - - -def create_ui(): - import modules.ui - - with gr.Group(): - with gr.Accordion("Open for Clip Aesthetic!", open=False): - with gr.Row(): - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", - value=0.9) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) - - with gr.Row(): - aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', - placeholder="Aesthetic learning rate", value="0.0001") - aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) - aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()), - label="Aesthetic imgs embedding", - value="None") - - modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings") - - with gr.Row(): - aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', - placeholder="This text is used to rotate the feature space of the imgs embs", - value="") - aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01, - value=0.1) - aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) - - return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative - - -aesthetic_clip_model = None - - -def aesthetic_clip(): - global aesthetic_clip_model - - if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path: - aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path) - aesthetic_clip_model.cpu() - - return aesthetic_clip_model - - -def generate_imgs_embd(name, folder, batch_size): - model = aesthetic_clip().to(device) - processor = CLIPProcessor.from_pretrained(model.name_or_path) - - with torch.no_grad(): - embs = [] - for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size), - desc=f"Generating embeddings for {name}"): - if shared.state.interrupted: - break - inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device) - outputs = model.get_image_features(**inputs).cpu() - embs.append(torch.clone(outputs)) - inputs.to("cpu") - del inputs, outputs - - embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True) - - # The generated embedding will be located here - path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") - torch.save(embs, path) - - model.cpu() - del processor - del embs - gc.collect() - torch.cuda.empty_cache() - res = f""" - Done generating embedding for {name}! - Aesthetic embedding saved to {html.escape(path)} - """ - shared.update_aesthetic_embeddings() - return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", - value="None"), \ - gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), - label="Imgs embedding", - value="None"), res, "" - - -def slerp(low, high, val): - low_norm = low / torch.norm(low, dim=1, keepdim=True) - high_norm = high / torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm * high_norm).sum(1)) - so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high - return res - - -class AestheticCLIP: - def __init__(self): - self.skip = False - self.aesthetic_steps = 0 - self.aesthetic_weight = 0 - self.aesthetic_lr = 0 - self.slerp = False - self.aesthetic_text_negative = "" - self.aesthetic_slerp_angle = 0 - self.aesthetic_imgs_text = "" - - self.image_embs_name = None - self.image_embs = None - self.load_image_embs(None) - - def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, - aesthetic_slerp=True, aesthetic_imgs_text="", - aesthetic_slerp_angle=0.15, - aesthetic_text_negative=False): - self.aesthetic_imgs_text = aesthetic_imgs_text - self.aesthetic_slerp_angle = aesthetic_slerp_angle - self.aesthetic_text_negative = aesthetic_text_negative - self.slerp = aesthetic_slerp - self.aesthetic_lr = aesthetic_lr - self.aesthetic_weight = aesthetic_weight - self.aesthetic_steps = aesthetic_steps - self.load_image_embs(image_embs_name) - - if self.image_embs_name is not None: - p.extra_generation_params.update({ - "Aesthetic LR": aesthetic_lr, - "Aesthetic weight": aesthetic_weight, - "Aesthetic steps": aesthetic_steps, - "Aesthetic embedding": self.image_embs_name, - "Aesthetic slerp": aesthetic_slerp, - "Aesthetic text": aesthetic_imgs_text, - "Aesthetic text negative": aesthetic_text_negative, - "Aesthetic slerp angle": aesthetic_slerp_angle, - }) - - def set_skip(self, skip): - self.skip = skip - - def load_image_embs(self, image_embs_name): - if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": - image_embs_name = None - self.image_embs_name = None - if image_embs_name is not None and self.image_embs_name != image_embs_name: - self.image_embs_name = image_embs_name - self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) - self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) - self.image_embs.requires_grad_(False) - - def __call__(self, z, remade_batch_tokens): - if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None: - tokenizer = shared.sd_model.cond_stage_model.tokenizer - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [ - [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in - remade_batch_tokens] - - tokens = torch.asarray(remade_batch_tokens).to(device) - - model = copy.deepcopy(aesthetic_clip()).to(device) - model.requires_grad_(True) - if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: - text_embs_2 = model.get_text_features( - **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) - if self.aesthetic_text_negative: - text_embs_2 = self.image_embs - text_embs_2 - text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) - img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) - else: - img_embs = self.image_embs - - with torch.enable_grad(): - - # We optimize the model to maximize the similarity - optimizer = optim.Adam( - model.text_model.parameters(), lr=self.aesthetic_lr - ) - - for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"): - text_embs = model.get_text_features(input_ids=tokens) - text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) - sim = text_embs @ img_embs.T - loss = -sim - optimizer.zero_grad() - loss.mean().backward() - optimizer.step() - - zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) - if opts.CLIP_stop_at_last_layers > 1: - zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] - zn = model.text_model.final_layer_norm(zn) - else: - zn = zn.last_hidden_state - model.cpu() - del model - gc.collect() - torch.cuda.empty_cache() - zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1) - if self.slerp: - z = slerp(z, zn, self.aesthetic_weight) - else: - z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight - - return z diff --git a/modules/images_history.py b/modules/images_history.py index 78fd0543..bc5cf11f 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -310,7 +310,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): forward = gr.Button('Prev batch') backward = gr.Button('Next batch') with gr.Column(scale=3): - load_info = gr.HTML(visible=not custom_dir) + load_info = gr.HTML(visible=not custom_dir) with gr.Row(visible=False) as warning: warning_box = gr.Textbox("Message", interactive=False) diff --git a/modules/img2img.py b/modules/img2img.py index eea5199b..8d9f7cf9 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args): +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_inpaint = mode == 1 is_batch = mode == 2 @@ -109,7 +109,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpainting_mask_invert=inpainting_mask_invert, ) - shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) + p.scripts = modules.scripts.scripts_txt2img + p.script_args = args if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index ff1ec4c9..372489f7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -104,6 +104,12 @@ class StableDiffusionProcessing(): self.seed_resize_from_h = 0 self.seed_resize_from_w = 0 + self.scripts = None + self.script_args = None + self.all_prompts = None + self.all_seeds = None + self.all_subseeds = None + def init(self, all_prompts, all_seeds, all_subseeds): pass @@ -350,32 +356,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed: shared.prompt_styles.apply_styles(p) if type(p.prompt) == list: - all_prompts = p.prompt + p.all_prompts = p.prompt else: - all_prompts = p.batch_size * p.n_iter * [p.prompt] + p.all_prompts = p.batch_size * p.n_iter * [p.prompt] if type(seed) == list: - all_seeds = seed + p.all_seeds = seed else: - all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] + p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))] if type(subseed) == list: - all_subseeds = subseed + p.all_subseeds = subseed else: - all_subseeds = [int(subseed) + x for x in range(len(all_prompts))] + p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] def infotext(iteration=0, position_in_batch=0): - return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) + return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() + if p.scripts is not None: + p.scripts.run_alwayson_scripts(p) + infotexts = [] output_images = [] with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): - p.init(all_prompts, all_seeds, all_subseeds) + p.init(p.all_prompts, p.all_seeds, p.all_subseeds) if state.job_count == -1: state.job_count = p.n_iter @@ -387,9 +396,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.interrupted: break - prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size] - seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size] - subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] + prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] + seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] + subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] if (len(prompts) == 0): break @@ -490,10 +499,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: index_of_first_image = 1 if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) devices.torch_gc() - return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) + return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py new file mode 100644 index 00000000..866b7acd --- /dev/null +++ b/modules/script_callbacks.py @@ -0,0 +1,42 @@ + +callbacks_model_loaded = [] +callbacks_ui_tabs = [] + + +def clear_callbacks(): + callbacks_model_loaded.clear() + callbacks_ui_tabs.clear() + + +def model_loaded_callback(sd_model): + for callback in callbacks_model_loaded: + callback(sd_model) + + +def ui_tabs_callback(): + res = [] + + for callback in callbacks_ui_tabs: + res += callback() or [] + + return res + + +def on_model_loaded(callback): + """register a function to be called when the stable diffusion model is created; the model is + passed as an argument""" + callbacks_model_loaded.append(callback) + + +def on_ui_tabs(callback): + """register a function to be called when the UI is creating new tabs. + The function must either return a None, which means no new tabs to be added, or a list, where + each element is a tuple: + (gradio_component, title, elem_id) + + gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks) + title is tab text displayed to user in the UI + elem_id is HTML id for the tab + """ + callbacks_ui_tabs.append(callback) + diff --git a/modules/scripts.py b/modules/scripts.py index 1039fa9c..65f25f49 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,86 +1,153 @@ import os import sys import traceback +from collections import namedtuple import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared +from modules import shared, paths, script_callbacks + +AlwaysVisible = object() + class Script: filename = None args_from = None args_to = None + alwayson = False + + infotext_fields = None + """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when + parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example + """ - # The title of the script. This is what will be displayed in the dropdown menu. def title(self): + """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" + raise NotImplementedError() - # How the script is displayed in the UI. See https://gradio.app/docs/#components - # for the different UI components you can use and how to create them. - # Most UI components can return a value, such as a boolean for a checkbox. - # The returned values are passed to the run method as parameters. def ui(self, is_img2img): + """this function should create gradio UI elements. See https://gradio.app/docs/#components + The return value should be an array of all components that are used in processing. + Values of those returned componenbts will be passed to run() and process() functions. + """ + pass - # Determines when the script should be shown in the dropdown menu via the - # returned value. As an example: - # is_img2img is True if the current tab is img2img, and False if it is txt2img. - # Thus, return is_img2img to only show the script on the img2img tab. def show(self, is_img2img): + """ + is_img2img is True if this function is called for the img2img interface, and Fasle otherwise + + This function should return: + - False if the script should not be shown in UI at all + - True if the script should be shown in UI if it's scelected in the scripts drowpdown + - script.AlwaysVisible if the script should be shown in UI at all times + """ + return True - # This is where the additional processing is implemented. The parameters include - # self, the model object "p" (a StableDiffusionProcessing class, see - # processing.py), and the parameters returned by the ui method. - # Custom functions can be defined here, and additional libraries can be imported - # to be used in processing. The return value should be a Processed object, which is - # what is returned by the process_images method. - def run(self, *args): + def run(self, p, *args): + """ + This function is called if the script has been selected in the script dropdown. + It must do all processing and return the Processed object with results, same as + one returned by processing.process_images. + + Usually the processing is done by calling the processing.process_images function. + + args contains all values returned by components from ui() + """ + raise NotImplementedError() - # The description method is currently unused. - # To add a description that appears when hovering over the title, amend the "titles" - # dict in script.js to include the script title (returned by title) as a key, and - # your description as the value. + def process(self, p, *args): + """ + This function is called before processing begins for AlwaysVisible scripts. + scripts. You can modify the processing object (p) here, inject hooks, etc. + """ + + pass + def describe(self): + """unused""" return "" +current_basedir = paths.script_path + + +def basedir(): + """returns the base directory for the current script. For scripts in the main scripts directory, + this is the main directory (where webui.py resides), and for scripts in extensions directory + (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic) + """ + return current_basedir + + scripts_data = [] +ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) +ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"]) + + +def list_scripts(scriptdirname, extension): + scripts_list = [] + + basedir = os.path.join(paths.script_path, scriptdirname) + if os.path.exists(basedir): + for filename in sorted(os.listdir(basedir)): + scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) + + extdir = os.path.join(paths.script_path, "extensions") + if os.path.exists(extdir): + for dirname in sorted(os.listdir(extdir)): + dirpath = os.path.join(extdir, dirname) + if not os.path.isdir(dirpath): + continue + for filename in sorted(os.listdir(os.path.join(dirpath, scriptdirname))): + scripts_list.append(ScriptFile(dirpath, filename, os.path.join(dirpath, scriptdirname, filename))) -def load_scripts(basedir): - if not os.path.exists(basedir): - return + scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] - for filename in sorted(os.listdir(basedir)): - path = os.path.join(basedir, filename) + return scripts_list - if os.path.splitext(path)[1].lower() != '.py': - continue - if not os.path.isfile(path): - continue +def load_scripts(): + global current_basedir + scripts_data.clear() + script_callbacks.clear_callbacks() + + scripts_list = list_scripts("scripts", ".py") + + syspath = sys.path + for scriptfile in sorted(scripts_list): try: - with open(path, "r", encoding="utf8") as file: + if scriptfile.basedir != paths.script_path: + sys.path = [scriptfile.basedir] + sys.path + current_basedir = scriptfile.basedir + + with open(scriptfile.path, "r", encoding="utf8") as file: text = file.read() from types import ModuleType - compiled = compile(text, path, 'exec') - module = ModuleType(filename) + compiled = compile(text, scriptfile.path, 'exec') + module = ModuleType(scriptfile.filename) exec(compiled, module.__dict__) for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): - scripts_data.append((script_class, path)) + scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir)) except Exception: - print(f"Error loading script: {filename}", file=sys.stderr) + print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + finally: + sys.path = syspath + current_basedir = paths.script_path + def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: @@ -96,56 +163,80 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs): class ScriptRunner: def __init__(self): self.scripts = [] + self.selectable_scripts = [] + self.alwayson_scripts = [] self.titles = [] + self.infotext_fields = [] def setup_ui(self, is_img2img): - for script_class, path in scripts_data: + for script_class, path, basedir in scripts_data: script = script_class() script.filename = path - if not script.show(is_img2img): - continue + visibility = script.show(is_img2img) - self.scripts.append(script) + if visibility == AlwaysVisible: + self.scripts.append(script) + self.alwayson_scripts.append(script) + script.alwayson = True - self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] + elif visibility: + self.scripts.append(script) + self.selectable_scripts.append(script) - dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") - dropdown.save_to_config = True - inputs = [dropdown] + self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] + + inputs = [None] + inputs_alwayson = [True] - for script in self.scripts: + def create_script_ui(script, inputs, inputs_alwayson): script.args_from = len(inputs) script.args_to = len(inputs) controls = wrap_call(script.ui, script.filename, "ui", is_img2img) if controls is None: - continue + return for control in controls: control.custom_script_source = os.path.basename(script.filename) - control.visible = False + if not script.alwayson: + control.visible = False + + if script.infotext_fields is not None: + self.infotext_fields += script.infotext_fields inputs += controls + inputs_alwayson += [script.alwayson for _ in controls] script.args_to = len(inputs) + for script in self.alwayson_scripts: + with gr.Group(): + create_script_ui(script, inputs, inputs_alwayson) + + dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") + dropdown.save_to_config = True + inputs[0] = dropdown + + for script in self.selectable_scripts: + create_script_ui(script, inputs, inputs_alwayson) + def select_script(script_index): - if 0 < script_index <= len(self.scripts): - script = self.scripts[script_index-1] + if 0 < script_index <= len(self.selectable_scripts): + script = self.selectable_scripts[script_index-1] args_from = script.args_from args_to = script.args_to else: args_from = 0 args_to = 0 - return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] + return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)] def init_field(title): if title == 'None': return script_index = self.titles.index(title) - script = self.scripts[script_index] + script = self.selectable_scripts[script_index] for i in range(script.args_from, script.args_to): inputs[i].visible = True @@ -164,7 +255,7 @@ class ScriptRunner: if script_index == 0: return None - script = self.scripts[script_index-1] + script = self.selectable_scripts[script_index-1] if script is None: return None @@ -176,6 +267,15 @@ class ScriptRunner: return processed + def run_alwayson_scripts(self, p): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.process(p, *script_args) + except Exception: + print(f"Error running alwayson script: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def reload_sources(self): for si, script in list(enumerate(self.scripts)): with open(script.filename, "r", encoding="utf8") as file: @@ -197,19 +297,21 @@ class ScriptRunner: self.scripts[si].args_from = args_from self.scripts[si].args_to = args_to + scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + def reload_script_body_only(): scripts_txt2img.reload_sources() scripts_img2img.reload_sources() -def reload_scripts(basedir): +def reload_scripts(): global scripts_txt2img, scripts_img2img - scripts_data.clear() - load_scripts(basedir) + load_scripts() scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 1f8587d1..0f10828e 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -332,7 +332,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers.append([1.0] * 75) z1 = self.process_tokens(tokens, multipliers) - z1 = shared.aesthetic_clip(z1, remade_batch_tokens) z = z1 if z is None else torch.cat((z, z1), axis=-2) remade_batch_tokens = rem_tokens diff --git a/modules/sd_models.py b/modules/sd_models.py index d99dbce8..f9b3063d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -7,7 +7,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices +from modules import shared, modelloader, devices, script_callbacks from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -238,6 +238,9 @@ def load_model(checkpoint_info=None): sd_hijack.model_hijack.hijack(sd_model) sd_model.eval() + shared.sd_model = sd_model + + script_callbacks.model_loaded_callback(sd_model) print(f"Model loaded.") return sd_model @@ -252,7 +255,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - shared.sd_model = load_model(checkpoint_info) + load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: diff --git a/modules/shared.py b/modules/shared.py index 0dbe360d..7d786f07 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -31,7 +31,6 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") -parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(models_path, 'aesthetic_embeddings'), help="aesthetic_embeddings directory(default: aesthetic_embeddings)") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") @@ -109,21 +108,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None - -os.makedirs(cmd_opts.aesthetic_embeddings_dir, exist_ok=True) -aesthetic_embeddings = {} - - -def update_aesthetic_embeddings(): - global aesthetic_embeddings - aesthetic_embeddings = {f.replace(".pt", ""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in - os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} - aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings) - - -update_aesthetic_embeddings() - - def reload_hypernetworks(): global hypernetworks @@ -415,9 +399,6 @@ sd_model = None clip_model = None -from modules.aesthetic_clip import AestheticCLIP -aesthetic_clip = AestheticCLIP() - progress_print_out = sys.stdout diff --git a/modules/txt2img.py b/modules/txt2img.py index 1761cfa2..c9d5a090 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -7,7 +7,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -36,7 +36,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: firstphase_height=firstphase_height if enable_hr else None, ) - shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) + p.scripts = modules.scripts.scripts_txt2img + p.script_args = args if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/ui.py b/modules/ui.py index 70a9cf10..c977482c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,10 +23,10 @@ import gradio as gr import gradio.utils import gradio.routes -from modules import sd_hijack, sd_models, localization +from modules import sd_hijack, sd_models, localization, script_callbacks from modules.paths import script_path -from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings +from modules.shared import opts, cmd_opts, restricted_opts if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags @@ -44,7 +44,6 @@ from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.aesthetic_clip as aesthetic_clip import modules.images_history as img_his @@ -662,8 +661,6 @@ def create_ui(wrap_gradio_gpu_call): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() - aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui() - with gr.Group(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) @@ -718,14 +715,6 @@ def create_ui(wrap_gradio_gpu_call): denoising_strength, firstphase_width, firstphase_height, - aesthetic_lr, - aesthetic_weight, - aesthetic_steps, - aesthetic_imgs, - aesthetic_slerp, - aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative ] + custom_inputs, outputs=[ @@ -804,14 +793,7 @@ def create_ui(wrap_gradio_gpu_call): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (firstphase_width, "First pass size-1"), (firstphase_height, "First pass size-2"), - (aesthetic_lr, "Aesthetic LR"), - (aesthetic_weight, "Aesthetic weight"), - (aesthetic_steps, "Aesthetic steps"), - (aesthetic_imgs, "Aesthetic embedding"), - (aesthetic_slerp, "Aesthetic slerp"), - (aesthetic_imgs_text, "Aesthetic text"), - (aesthetic_text_negative, "Aesthetic text negative"), - (aesthetic_slerp_angle, "Aesthetic slerp angle"), + *modules.scripts.scripts_txt2img.infotext_fields ] txt2img_preview_params = [ @@ -896,8 +878,6 @@ def create_ui(wrap_gradio_gpu_call): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() - aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui() - with gr.Group(): custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) @@ -988,14 +968,6 @@ def create_ui(wrap_gradio_gpu_call): inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, - aesthetic_lr_im, - aesthetic_weight_im, - aesthetic_steps_im, - aesthetic_imgs_im, - aesthetic_slerp_im, - aesthetic_imgs_text_im, - aesthetic_slerp_angle_im, - aesthetic_text_negative_im, ] + custom_inputs, outputs=[ img2img_gallery, @@ -1087,14 +1059,7 @@ def create_ui(wrap_gradio_gpu_call): (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), (denoising_strength, "Denoising strength"), - (aesthetic_lr_im, "Aesthetic LR"), - (aesthetic_weight_im, "Aesthetic weight"), - (aesthetic_steps_im, "Aesthetic steps"), - (aesthetic_imgs_im, "Aesthetic embedding"), - (aesthetic_slerp_im, "Aesthetic slerp"), - (aesthetic_imgs_text_im, "Aesthetic text"), - (aesthetic_text_negative_im, "Aesthetic text negative"), - (aesthetic_slerp_angle_im, "Aesthetic slerp angle"), + *modules.scripts.scripts_img2img.infotext_fields ] token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) @@ -1217,9 +1182,9 @@ def create_ui(wrap_gradio_gpu_call): ) #images history images_history_switch_dict = { - "fn":modules.generation_parameters_copypaste.connect_paste, - "t2i":txt2img_paste_fields, - "i2i":img2img_paste_fields + "fn": modules.generation_parameters_copypaste.connect_paste, + "t2i": txt2img_paste_fields, + "i2i": img2img_paste_fields } images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) @@ -1264,18 +1229,6 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create embedding", variant='primary') - with gr.Tab(label="Create aesthetic images embedding"): - - new_embedding_name_ae = gr.Textbox(label="Name") - process_src_ae = gr.Textbox(label='Source directory') - batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256) - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding_ae = gr.Button(value="Create images embedding", variant='primary') - with gr.Tab(label="Create hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) @@ -1375,21 +1328,6 @@ def create_ui(wrap_gradio_gpu_call): ] ) - create_embedding_ae.click( - fn=aesthetic_clip.generate_imgs_embd, - inputs=[ - new_embedding_name_ae, - process_src_ae, - batch_ae - ], - outputs=[ - aesthetic_imgs, - aesthetic_imgs_im, - ti_output, - ti_outcome, - ] - ) - create_hypernetwork.click( fn=modules.hypernetworks.ui.create_hypernetwork, inputs=[ @@ -1580,10 +1518,10 @@ Requested path was: {f} if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() + oldval = opts.data.get(key, None) if cmd_opts.hide_ui_dir_config and key in restricted_opts: return gr.update(value=oldval), opts.dumpjson() - oldval = opts.data.get(key, None) opts.data[key] = value if oldval != value: @@ -1692,9 +1630,12 @@ Requested path was: {f} (images_history, "Image Browser", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), - (settings_interface, "Settings", "settings"), ] + interfaces += script_callbacks.ui_tabs_callback() + + interfaces += [(settings_interface, "Settings", "settings")] + with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file: css = file.read() diff --git a/webui.py b/webui.py index 87589064..b1deca1b 100644 --- a/webui.py +++ b/webui.py @@ -71,6 +71,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) + def initialize(): modelloader.cleanup_models() modules.sd_models.setup_model() @@ -79,9 +80,9 @@ def initialize(): shared.face_restorers.append(modules.face_restoration.FaceRestoration()) modelloader.load_upscalers() - modules.scripts.load_scripts(os.path.join(script_path, "scripts")) + modules.scripts.load_scripts() - shared.sd_model = modules.sd_models.load_model() + modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) @@ -145,7 +146,7 @@ def webui(): sd_samplers.set_samplers() print('Reloading Custom Scripts') - modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) + modules.scripts.reload_scripts() print('Reloading modules: modules.ui') importlib.reload(modules.ui) print('Refreshing Model List') -- cgit v1.2.3 From 6398dc9b1049f242576ca309f95a3fb1e654951c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 13:34:49 +0300 Subject: further support for extensions --- .gitignore | 1 + README.md | 3 +-- modules/scripts.py | 44 +++++++++++++++++++++++++++++++++++--------- modules/ui.py | 19 ++++++++++--------- style.css | 2 +- 5 files changed, 48 insertions(+), 21 deletions(-) (limited to 'modules') diff --git a/.gitignore b/.gitignore index 2f1e08ed..8fa05852 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ notification.mp3 /SwinIR /textual_inversion .vscode +/extensions diff --git a/README.md b/README.md index 5b5dc8ba..6853aea0 100644 --- a/README.md +++ b/README.md @@ -83,8 +83,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - Estimated completion time in progress bar - API - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. -- Aesthetic Gradients, a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - +- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. diff --git a/modules/scripts.py b/modules/scripts.py index 65f25f49..9323af3e 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -102,17 +102,39 @@ def list_scripts(scriptdirname, extension): if os.path.exists(extdir): for dirname in sorted(os.listdir(extdir)): dirpath = os.path.join(extdir, dirname) - if not os.path.isdir(dirpath): + scriptdirpath = os.path.join(dirpath, scriptdirname) + + if not os.path.isdir(scriptdirpath): continue - for filename in sorted(os.listdir(os.path.join(dirpath, scriptdirname))): - scripts_list.append(ScriptFile(dirpath, filename, os.path.join(dirpath, scriptdirname, filename))) + for filename in sorted(os.listdir(scriptdirpath)): + scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename))) scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] return scripts_list +def list_files_with_name(filename): + res = [] + + dirs = [paths.script_path] + + extdir = os.path.join(paths.script_path, "extensions") + if os.path.exists(extdir): + dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))] + + for dirpath in dirs: + if not os.path.isdir(dirpath): + continue + + path = os.path.join(dirpath, filename) + if os.path.isfile(filename): + res.append(path) + + return res + + def load_scripts(): global current_basedir scripts_data.clear() @@ -276,7 +298,7 @@ class ScriptRunner: print(f"Error running alwayson script: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - def reload_sources(self): + def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): with open(script.filename, "r", encoding="utf8") as file: args_from = script.args_from @@ -286,9 +308,12 @@ class ScriptRunner: from types import ModuleType - compiled = compile(text, filename, 'exec') - module = ModuleType(script.filename) - exec(compiled, module.__dict__) + module = cache.get(filename, None) + if module is None: + compiled = compile(text, filename, 'exec') + module = ModuleType(script.filename) + exec(compiled, module.__dict__) + cache[filename] = module for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): @@ -303,8 +328,9 @@ scripts_img2img = ScriptRunner() def reload_script_body_only(): - scripts_txt2img.reload_sources() - scripts_img2img.reload_sources() + cache = {} + scripts_txt2img.reload_sources(cache) + scripts_img2img.reload_sources(cache) def reload_scripts(): diff --git a/modules/ui.py b/modules/ui.py index c977482c..29986124 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1636,13 +1636,15 @@ Requested path was: {f} interfaces += [(settings_interface, "Settings", "settings")] - with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file: - css = file.read() + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + with open(cssfile, "r", encoding="utf8") as file: + css += file.read() + "\n" if os.path.exists(os.path.join(script_path, "user.css")): with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - usercss = file.read() - css += usercss + css += file.read() + "\n" if not cmd_opts.no_progressbar_hiding: css += css_hide_progressbar @@ -1865,9 +1867,9 @@ def load_javascript(raw_response): with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: javascript = f'' - jsdir = os.path.join(script_path, "javascript") - for filename in sorted(os.listdir(jsdir)): - with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: + scripts_list = modules.scripts.list_scripts("javascript", ".js") + for basedir, filename, path in scripts_list: + with open(path, "r", encoding="utf8") as jsfile: javascript += f"\n" if cmd_opts.theme is not None: @@ -1885,6 +1887,5 @@ def load_javascript(raw_response): gradio.routes.templates.TemplateResponse = template_response -reload_javascript = partial(load_javascript, - gradio.routes.templates.TemplateResponse) +reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse) reload_javascript() diff --git a/style.css b/style.css index 5d2bacc9..26ae36a5 100644 --- a/style.css +++ b/style.css @@ -477,7 +477,7 @@ input[type="range"]{ padding: 0; } -#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_aesthetic_embeddings{ +#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{ max-width: 2.5em; min-width: 2.5em; height: 2.4em; -- cgit v1.2.3 From 50b5504401e50b6c94eba41b37fe212b2f27b792 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 14:04:14 +0300 Subject: remove parsing command line from devices.py --- modules/devices.py | 14 +++++--------- modules/lowvram.py | 9 ++++----- 2 files changed, 9 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 8a159282..dc1f3cdd 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -15,14 +15,10 @@ def extract_device_id(args, name): def get_optimal_device(): if torch.cuda.is_available(): - # CUDA device selection support: - if "shared" not in sys.modules: - commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop. - sys.argv += shlex.split(commandline_args) - device_id = extract_device_id(sys.argv, '--device-id') - else: - device_id = shared.cmd_opts.device_id - + from modules import shared + + device_id = shared.cmd_opts.device_id + if device_id is not None: cuda_device = f"cuda:{device_id}" return torch.device(cuda_device) @@ -49,7 +45,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") -device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() +device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 diff --git a/modules/lowvram.py b/modules/lowvram.py index 7eba1349..f327c3df 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -1,9 +1,8 @@ import torch -from modules.devices import get_optimal_device +from modules import devices module_in_gpu = None cpu = torch.device("cpu") -device = gpu = get_optimal_device() def send_everything_to_cpu(): @@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram): if module_in_gpu is not None: module_in_gpu.to(cpu) - module.to(gpu) + module.to(devices.device) module_in_gpu = module # see below for register_forward_pre_hook; @@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram): # send the model to GPU. Then put modules back. the modules will be in CPU. stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None - sd_model.to(device) + sd_model.to(devices.device) sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored # register hooks for those the first two models @@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram): # so that only one of them is in GPU at a time stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None - sd_model.model.to(device) + sd_model.model.to(devices.device) diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored # install hooks for bits of third model -- cgit v1.2.3 From 0e8ca8e7af05be22d7d2c07a47c3c7febe0f0ab6 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Sat, 22 Oct 2022 11:07:00 +0000 Subject: add dropout --- modules/hypernetworks/hypernetwork.py | 68 +++++++++++++++++++++-------------- modules/hypernetworks/ui.py | 10 +++--- modules/ui.py | 43 +++++++++++----------- 3 files changed, 70 insertions(+), 51 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 905cbeef..e493f366 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -1,47 +1,60 @@ +import csv import datetime import glob import html import os import sys import traceback -import tqdm -import csv +import modules.textual_inversion.dataset import torch - -from ldm.util import default -from modules import devices, shared, processing, sd_models -import torch -from torch import einsum +import tqdm from einops import rearrange, repeat -import modules.textual_inversion.dataset +from ldm.util import default +from modules import devices, processing, sd_models, shared from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler +from torch import einsum class HypernetworkModule(torch.nn.Module): multiplier = 1.0 - activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU, - "swish": torch.nn.Hardswish} - - def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None): + activation_dict = { + "relu": torch.nn.ReLU, + "leakyrelu": torch.nn.LeakyReLU, + "elu": torch.nn.ELU, + "swish": torch.nn.Hardswish, + } + + def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): super().__init__() assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - + assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'" + linears = [] for i in range(len(layer_structure) - 1): + + # Add a fully-connected layer linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) - # if skip_first_layer because first parameters potentially contain negative values - # if i < 1: continue - if activation_func in HypernetworkModule.activation_dict: - linears.append(HypernetworkModule.activation_dict[activation_func]()) + + # Add an activation func + if activation_func == "linear": + pass + elif activation_func in self.activation_dict: + linears.append(self.activation_dict[activation_func]()) else: - print("Invalid key {} encountered as activation function!".format(activation_func)) - # if use_dropout: - # linears.append(torch.nn.Dropout(p=0.3)) + raise NotImplementedError( + "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'" + ) + + # Add dropout + if use_dropout: + linears.append(torch.nn.Dropout(p=0.3)) + + # Add layer normalization if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -93,7 +106,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): self.filename = None self.name = name self.layers = {} @@ -101,13 +114,14 @@ class Hypernetwork: self.sd_checkpoint = None self.sd_checkpoint_name = None self.layer_structure = layer_structure - self.add_layer_norm = add_layer_norm self.activation_func = activation_func + self.add_layer_norm = add_layer_norm + self.use_dropout = use_dropout for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), - HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), ) def weights(self): @@ -129,8 +143,9 @@ class Hypernetwork: state_dict['step'] = self.step state_dict['name'] = self.name state_dict['layer_structure'] = self.layer_structure - state_dict['is_layer_norm'] = self.add_layer_norm state_dict['activation_func'] = self.activation_func + state_dict['is_layer_norm'] = self.add_layer_norm + state_dict['use_dropout'] = self.use_dropout state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name @@ -144,8 +159,9 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu') self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) - self.add_layer_norm = state_dict.get('is_layer_norm', False) self.activation_func = state_dict.get('activation_func', None) + self.add_layer_norm = state_dict.get('is_layer_norm', False) + self.use_dropout = state_dict.get('use_dropout', False) for size, sd in state_dict.items(): if type(size) == int: diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 1a5a27d8..5f6f17b6 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -3,14 +3,13 @@ import os import re import gradio as gr - -import modules.textual_inversion.textual_inversion import modules.textual_inversion.preprocess -from modules import sd_hijack, shared, devices +import modules.textual_inversion.textual_inversion +from modules import devices, sd_hijack, shared from modules.hypernetworks import hypernetwork -def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None): +def create_hypernetwork(name, enable_sizes, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" @@ -21,8 +20,9 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm name=name, enable_sizes=[int(x) for x in enable_sizes], layer_structure=layer_structure, - add_layer_norm=add_layer_norm, activation_func=activation_func, + add_layer_norm=add_layer_norm, + use_dropout=use_dropout, ) hypernet.save(fn) diff --git a/modules/ui.py b/modules/ui.py index 716f14b8..d4b32c05 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -5,43 +5,44 @@ import json import math import mimetypes import os +import platform import random +import subprocess as sp import sys import tempfile import time import traceback -import platform -import subprocess as sp from functools import partial, reduce +import gradio as gr +import gradio.routes +import gradio.utils import numpy as np +import piexif import torch from PIL import Image, PngImagePlugin -import piexif -import gradio as gr -import gradio.utils -import gradio.routes - -from modules import sd_hijack, sd_models, localization +from modules import localization, sd_hijack, sd_models from modules.paths import script_path -from modules.shared import opts, cmd_opts, restricted_opts +from modules.shared import cmd_opts, opts, restricted_opts + if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags -import modules.shared as shared -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.sd_hijack import model_hijack + +import modules.codeformer_model +import modules.generation_parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.images_history as img_his import modules.ldsr_model import modules.scripts -import modules.gfpgan_model -import modules.codeformer_model +import modules.shared as shared import modules.styles -import modules.generation_parameters_copypaste +import modules.textual_inversion.ui from modules import prompt_parser from modules.images import save_image -import modules.textual_inversion.ui -import modules.hypernetworks.ui -import modules.images_history as img_his +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -1223,8 +1224,9 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") + new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"]) new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") - new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"]) + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") with gr.Row(): with gr.Column(scale=3): @@ -1308,8 +1310,9 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_name, new_hypernetwork_sizes, new_hypernetwork_layer_structure, - new_hypernetwork_add_layer_norm, new_hypernetwork_activation_func, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout ], outputs=[ train_hypernetwork_name, -- cgit v1.2.3 From 1cd3ed7def40198f46d30f74dd37d2906ebdbaa6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 14:28:56 +0300 Subject: fix for extensions without style.css --- modules/ui.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 29986124..d8d52db1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1639,6 +1639,9 @@ Requested path was: {f} css = "" for cssfile in modules.scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + with open(cssfile, "r", encoding="utf8") as file: css += file.read() + "\n" -- cgit v1.2.3 From 7fd90128eb6d1820045bfe2c2c1269661023a712 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 14:48:43 +0300 Subject: added a guard for hypernet training that will stop early if weights are getting no gradients --- modules/hypernetworks/hypernetwork.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 47d91ea5..46039a49 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -310,6 +310,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) + steps_without_grad = 0 + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: hypernetwork.step = i + ititial_step @@ -332,8 +334,17 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log losses[hypernetwork.step % losses.shape[0]] = loss.item() optimizer.zero_grad() + weights[0].grad = None loss.backward() + + if weights[0].grad is None: + steps_without_grad += 1 + else: + steps_without_grad = 0 + assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' + optimizer.step() + mean_loss = losses.mean() if torch.isnan(mean_loss): raise RuntimeError("Loss diverged.") -- cgit v1.2.3 From fccba4729db341a299db3343e3264fecd9459a07 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Sat, 22 Oct 2022 12:02:41 +0000 Subject: add an option to avoid dying relu --- modules/hypernetworks/hypernetwork.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b7a04038..3132a56c 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -32,7 +32,6 @@ class HypernetworkModule(torch.nn.Module): assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'" linears = [] for i in range(len(layer_structure) - 1): @@ -43,12 +42,13 @@ class HypernetworkModule(torch.nn.Module): # Add an activation func if activation_func == "linear" or activation_func is None: pass + # If ReLU, Skip adding it to the first layer to avoid dying ReLU + elif activation_func == "relu" and i < 1: + pass elif activation_func in self.activation_dict: linears.append(self.activation_dict[activation_func]()) else: - raise RuntimeError( - "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'" - ) + raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') # Add dropout if use_dropout: @@ -166,8 +166,8 @@ class Hypernetwork: for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func), - HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func), + HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), ) self.name = state_dict.get('name', self.name) -- cgit v1.2.3 From 7912acef725832debef58c4c7bf8ec22fb446c0b Mon Sep 17 00:00:00 2001 From: discus0434 Date: Sat, 22 Oct 2022 13:00:44 +0000 Subject: small fix --- modules/hypernetworks/hypernetwork.py | 12 +++++------- modules/ui.py | 1 - 2 files changed, 5 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3132a56c..7d12e0ff 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -42,22 +42,20 @@ class HypernetworkModule(torch.nn.Module): # Add an activation func if activation_func == "linear" or activation_func is None: pass - # If ReLU, Skip adding it to the first layer to avoid dying ReLU - elif activation_func == "relu" and i < 1: - pass elif activation_func in self.activation_dict: linears.append(self.activation_dict[activation_func]()) else: raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') - # Add dropout - if use_dropout: - linears.append(torch.nn.Dropout(p=0.3)) - # Add layer normalization if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) + # Add dropout + if use_dropout: + p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2 + linears.append(torch.nn.Dropout(p=p)) + self.linear = torch.nn.Sequential(*linears) if state_dict is not None: diff --git a/modules/ui.py b/modules/ui.py index cd118552..eca887ca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1244,7 +1244,6 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") - new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"]) with gr.Row(): with gr.Column(scale=3): -- cgit v1.2.3 From 6a4fa73a38935a18779ce1809892730fd1572bee Mon Sep 17 00:00:00 2001 From: discus0434 Date: Sat, 22 Oct 2022 13:44:39 +0000 Subject: small fix --- modules/hypernetworks/hypernetwork.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3372aae2..3bc71ee5 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -51,10 +51,9 @@ class HypernetworkModule(torch.nn.Module): if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) - # Add dropout - if use_dropout: - p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2 - linears.append(torch.nn.Dropout(p=p)) + # Add dropout expect last layer + if use_dropout and i < len(layer_structure) - 3: + linears.append(torch.nn.Dropout(p=0.3)) self.linear = torch.nn.Sequential(*linears) -- cgit v1.2.3 From d37cfffd537cd29309afbcb192c4f979995c6a34 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 19:18:56 +0300 Subject: added callback for creating new settings in extensions --- modules/script_callbacks.py | 11 +++++++++++ modules/shared.py | 19 +++++++++++++++++-- modules/ui.py | 6 +++++- 3 files changed, 33 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 866b7acd..1270e50f 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,6 +1,7 @@ callbacks_model_loaded = [] callbacks_ui_tabs = [] +callbacks_ui_settings = [] def clear_callbacks(): @@ -22,6 +23,11 @@ def ui_tabs_callback(): return res +def ui_settings_callback(): + for callback in callbacks_ui_settings: + callback() + + def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument""" @@ -40,3 +46,8 @@ def on_ui_tabs(callback): """ callbacks_ui_tabs.append(callback) + +def on_ui_settings(callback): + """register a function to be called before UI settingsare populated; add your settings + by using shared.opts.add_option(shared.OptionInfo(...)) """ + callbacks_ui_settings.append(callback) diff --git a/modules/shared.py b/modules/shared.py index 5d83971e..d9cb65ef 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -165,13 +165,13 @@ def realesrgan_models_names(): class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None): + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): self.default = default self.label = label self.component = component self.component_args = component_args self.onchange = onchange - self.section = None + self.section = section self.refresh = refresh @@ -327,6 +327,7 @@ options_templates.update(options_section(('images-history', "Images Browser"), { })) + class Options: data = None data_labels = options_templates @@ -389,6 +390,20 @@ class Options: d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} return json.dumps(d) + def add_option(self, key, info): + self.data_labels[key] = info + + def reorder(self): + """reorder settings so that all items related to section always go together""" + + section_ids = {} + settings_items = self.data_labels.items() + for k, item in settings_items: + if item.section not in section_ids: + section_ids[item.section] = len(section_ids) + + self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])} + opts = Options() if os.path.exists(config_filename): diff --git a/modules/ui.py b/modules/ui.py index d8d52db1..2849b111 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1461,6 +1461,9 @@ def create_ui(wrap_gradio_gpu_call): components = [] component_dict = {} + script_callbacks.ui_settings_callback() + opts.reorder() + def open_folder(f): if not os.path.exists(f): print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') @@ -1564,7 +1567,8 @@ Requested path was: {f} previous_section = item.section - gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='

{}

'.format(item.section[1])) + elem_id, text = item.section + gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

{}

'.format(text)) if k in quicksettings_names: quicksettings_list.append((i, k, item)) -- cgit v1.2.3 From dbc8ab65f6d496459a76547776b656c96ad1350d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 19:19:17 +0300 Subject: typo --- modules/script_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 1270e50f..5bcccd67 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -48,6 +48,6 @@ def on_ui_tabs(callback): def on_ui_settings(callback): - """register a function to be called before UI settingsare populated; add your settings + """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ callbacks_ui_settings.append(callback) -- cgit v1.2.3 From 72383abacdc6a101704a6f73758ce4d0bb68c9d1 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Sat, 22 Oct 2022 16:50:07 +0200 Subject: Deepdanbooru linux fix --- modules/deepbooru.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 8914662d..3c34ab7c 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -50,7 +50,8 @@ def create_deepbooru_process(threshold, deepbooru_opts): the tags. """ from modules import shared # prevents circular reference - shared.deepbooru_process_manager = multiprocessing.Manager() + context = multiprocessing.get_context("spawn") + shared.deepbooru_process_manager = context.Manager() shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() shared.deepbooru_process_return["value"] = -1 -- cgit v1.2.3 From e38625011cd4955da4bc67fe95d1d0f4c0c53899 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Sat, 22 Oct 2022 16:56:52 +0200 Subject: fix part2 --- modules/deepbooru.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 3c34ab7c..8bbc90a4 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -55,7 +55,7 @@ def create_deepbooru_process(threshold, deepbooru_opts): shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts)) + shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts)) shared.deepbooru_process.start() -- cgit v1.2.3 From 324c7c732dd9afc3d4c397c354797ae5d655b514 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 20:09:37 +0300 Subject: record First pass size as 0x0 for #3328 --- modules/processing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 372489f7..27c669b0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -524,6 +524,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 + self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" + if self.firstphase_width == 0 or self.firstphase_height == 0: desired_pixel_count = 512 * 512 actual_pixel_count = self.width * self.height @@ -545,7 +547,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): firstphase_width_truncated = self.firstphase_height * self.width / self.height firstphase_height_truncated = self.firstphase_height - self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f -- cgit v1.2.3 From 0df94d3fcf9d1fc47c4d39039352a3d5b3380c1f Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Sat, 22 Oct 2022 12:59:21 -0400 Subject: fix aesthetic gradients doing nothing after loading a different model --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index f9b3063d..49dc3238 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -236,12 +236,11 @@ def load_model(checkpoint_info=None): sd_model.to(shared.device) sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) sd_model.eval() shared.sd_model = sd_model - script_callbacks.model_loaded_callback(sd_model) - print(f"Model loaded.") return sd_model @@ -268,6 +267,7 @@ def reload_model_weights(sd_model, info=None): load_model_weights(sd_model, checkpoint_info) sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) -- cgit v1.2.3 From 321bacc6a9eaf4a25f31279f288fa752be507a20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 20:15:12 +0300 Subject: call model_loaded_callback after setting shared.sd_model in case scripts refer to it using that --- modules/sd_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 49dc3238..e697bb72 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -236,11 +236,12 @@ def load_model(checkpoint_info=None): sd_model.to(shared.device) sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) sd_model.eval() shared.sd_model = sd_model + script_callbacks.model_loaded_callback(sd_model) + print(f"Model loaded.") return sd_model -- cgit v1.2.3 From 24694e5983d0944b901892cb101878e6dec89a20 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 01:57:58 +0900 Subject: Update hypernetwork.py --- modules/hypernetworks/hypernetwork.py | 55 ++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3bc71ee5..81132be4 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum +from statistics import stdev, mean class HypernetworkModule(torch.nn.Module): multiplier = 1.0 @@ -268,6 +269,32 @@ def stack_conds(conds): return torch.stack(conds) +def log_statistics(loss_info:dict, key, value): + if key not in loss_info: + loss_info[key] = [value] + else: + loss_info[key].append(value) + if len(loss_info) > 1024: + loss_info.pop(0) + + +def statistics(data): + total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})" + recent_data = data[-32:] + recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})" + return total_information, recent_information + + +def report_statistics(loss_info:dict): + keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x])) + for key in keys: + info, recent = statistics(loss_info[key]) + print("Loss statistics for file " + key) + print(info) + print(recent) + + + def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images @@ -310,7 +337,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log for weight in weights: weight.requires_grad = True - losses = torch.zeros((32,)) + size = len(ds.indexes) + loss_dict = {} + losses = torch.zeros((size,)) + previous_mean_loss = 0 + print("Mean loss of {} elements".format(size)) last_saved_file = "" last_saved_image = "" @@ -329,7 +360,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: hypernetwork.step = i + ititial_step - + if loss_dict and i % size == 0: + previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict) + scheduler.apply(optimizer, hypernetwork.step) if scheduler.finished: break @@ -346,7 +379,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log del c losses[hypernetwork.step % losses.shape[0]] = loss.item() - + for entry in entries: + log_statistics(loss_dict, entry.filename, loss.item()) + optimizer.zero_grad() weights[0].grad = None loss.backward() @@ -359,10 +394,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log optimizer.step() - mean_loss = losses.mean() - if torch.isnan(mean_loss): + if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): raise RuntimeError("Loss diverged.") - pbar.set_description(f"loss: {mean_loss:.7f}") + pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}") if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. @@ -371,7 +405,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log hypernetwork.save(last_saved_file) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { - "loss": f"{mean_loss:.7f}", + "loss": f"{previous_mean_loss:.7f}", "learn_rate": scheduler.learn_rate }) @@ -420,14 +454,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.state.textinfo = f"""

-Loss: {mean_loss:.7f}
+Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
Last prompt: {html.escape(entries[0].cond_text)}
Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - + + report_statistics(loss_dict) checkpoint = sd_models.select_checkpoint() hypernetwork.sd_checkpoint = checkpoint.hash @@ -438,5 +473,3 @@ Last saved image: {html.escape(last_saved_image)}
hypernetwork.save(filename) return hypernetwork, filename - - -- cgit v1.2.3 From 4fdb53c1e9962507fc8336dad9a0fabfe6c418c0 Mon Sep 17 00:00:00 2001 From: Unnoen Date: Wed, 19 Oct 2022 21:38:10 +1100 Subject: Generate grid preview for progress image --- modules/sd_samplers.py | 26 +++++++++++++++++++++++++- modules/shared.py | 1 + modules/ui.py | 5 ++++- 3 files changed, 30 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index f58a29b9..74a480e5 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -7,7 +7,7 @@ import inspect import k_diffusion.sampling import ldm.models.diffusion.ddim import ldm.models.diffusion.plms -from modules import prompt_parser, devices, processing +from modules import prompt_parser, devices, processing, images from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -89,6 +89,30 @@ def sample_to_image(samples): x_sample = x_sample.astype(np.uint8) return Image.fromarray(x_sample) +def samples_to_image_grid(samples): + progress_images = [] + for i in range(len(samples)): + # Decode the samples individually to reduce VRAM usage at the cost of a bit of speed. + x_sample = processing.decode_first_stage(shared.sd_model, samples[i:i+1])[0] + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + progress_images.append(Image.fromarray(x_sample)) + + return images.image_grid(progress_images) + +def samples_to_image_grid_combined(samples): + progress_images = [] + # Decode all samples at once to increase speed at the cost of VRAM usage. + x_samples = processing.decode_first_stage(shared.sd_model, samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + progress_images.append(Image.fromarray(x_sample)) + + return images.image_grid(progress_images) def store_latent(decoded): state.current_latent = decoded diff --git a/modules/shared.py b/modules/shared.py index d9cb65ef..95d6e225 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -294,6 +294,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), + "progress_decode_combined": OptionInfo(False, "Decode all progress images at once. (Slighty speeds up progress generation but consumes significantly more VRAM with large batches.)"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), diff --git a/modules/ui.py b/modules/ui.py index 56c233ab..de0abc7e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -318,7 +318,10 @@ def check_progress_call(id_part): if shared.parallel_processing_allowed: if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None: - shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent) + if opts.progress_decode_combined: + shared.state.current_image = modules.sd_samplers.samples_to_image_grid_combined(shared.state.current_latent) + else: + shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent) shared.state.current_image_sampling_step = shared.state.sampling_step image = shared.state.current_image -- cgit v1.2.3 From d213d6ca6f90094cb45c11e2f3cb37d25a8d1f94 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 20:48:13 +0300 Subject: removed the option to use 2x more memory when generating previews added an option to always only show one image in previews removed duplicate code --- modules/sd_samplers.py | 35 ++++++++++------------------------- modules/shared.py | 2 +- modules/ui.py | 6 +++--- 3 files changed, 14 insertions(+), 29 deletions(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 74a480e5..0b408a70 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -71,6 +71,7 @@ sampler_extra_params = { 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], } + def setup_img2img_steps(p, steps=None): if opts.img2img_fix_steps or steps is not None: steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 @@ -82,37 +83,21 @@ def setup_img2img_steps(p, steps=None): return steps, t_enc -def sample_to_image(samples): - x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0] +def single_sample_to_image(sample): + x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) return Image.fromarray(x_sample) + +def sample_to_image(samples): + return single_sample_to_image(samples[0]) + + def samples_to_image_grid(samples): - progress_images = [] - for i in range(len(samples)): - # Decode the samples individually to reduce VRAM usage at the cost of a bit of speed. - x_sample = processing.decode_first_stage(shared.sd_model, samples[i:i+1])[0] - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - progress_images.append(Image.fromarray(x_sample)) - - return images.image_grid(progress_images) - -def samples_to_image_grid_combined(samples): - progress_images = [] - # Decode all samples at once to increase speed at the cost of VRAM usage. - x_samples = processing.decode_first_stage(shared.sd_model, samples) - x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - - for x_sample in x_samples: - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - progress_images.append(Image.fromarray(x_sample)) - - return images.image_grid(progress_images) + return images.image_grid([single_sample_to_image(sample) for sample in samples]) + def store_latent(decoded): state.current_latent = decoded diff --git a/modules/shared.py b/modules/shared.py index 95d6e225..25bfc895 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -294,7 +294,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), - "progress_decode_combined": OptionInfo(False, "Decode all progress images at once. (Slighty speeds up progress generation but consumes significantly more VRAM with large batches.)"), + "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), diff --git a/modules/ui.py b/modules/ui.py index de0abc7e..ffa14cac 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -318,10 +318,10 @@ def check_progress_call(id_part): if shared.parallel_processing_allowed: if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None: - if opts.progress_decode_combined: - shared.state.current_image = modules.sd_samplers.samples_to_image_grid_combined(shared.state.current_latent) - else: + if opts.show_progress_grid: shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent) + else: + shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent) shared.state.current_image_sampling_step = shared.state.sampling_step image = shared.state.current_image -- cgit v1.2.3 From be748e8b086bd9834d08bdd9160649a5e7700af7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 22:05:22 +0300 Subject: add --freeze-settings commandline argument to disable changing settings --- modules/shared.py | 1 + modules/ui.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 25bfc895..b55371d3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -64,6 +64,7 @@ parser.add_argument("--port", type=int, help="launch gradio with given server po parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json')) parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) +parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) diff --git a/modules/ui.py b/modules/ui.py index ffa14cac..2311572c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -580,6 +580,9 @@ def apply_setting(key, value): if value is None: return gr.update() + if shared.cmd_opts.freeze_settings: + return gr.update() + # dont allow model to be swapped when model hash exists in prompt if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: return gr.update() @@ -1501,6 +1504,8 @@ Requested path was: {f} def run_settings(*args): changed = 0 + assert not shared.cmd_opts.freeze_settings, "changing settings is disabled" + for key, value, comp in zip(opts.data_labels.keys(), args, components): if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default): return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson() @@ -1530,6 +1535,8 @@ Requested path was: {f} return f'{changed} settings changed.', opts.dumpjson() def run_settings_single(value, key): + assert not shared.cmd_opts.freeze_settings, "changing settings is disabled" + if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() @@ -1582,7 +1589,7 @@ Requested path was: {f} elem_id, text = item.section gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

{}

'.format(text)) - if k in quicksettings_names: + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: quicksettings_list.append((i, k, item)) components.append(dummy_component) else: @@ -1615,7 +1622,7 @@ Requested path was: {f} def reload_scripts(): modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page + reload_javascript() # need to refresh the html page reload_script_bodies.click( fn=reload_scripts, -- cgit v1.2.3 From ca5a9e79dc28eeaa3a161427a82e34703bf15765 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 22:06:54 +0300 Subject: fix for img2img color correction in a batch #3218 --- modules/processing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 27c669b0..b1877b80 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -403,8 +403,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if (len(prompts) == 0): break - #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) - #c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) @@ -716,6 +714,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0) if self.overlay_images is not None: self.overlay_images = self.overlay_images * self.batch_size + + if self.color_corrections is not None and len(self.color_corrections) == 1: + self.color_corrections = self.color_corrections * self.batch_size + elif len(imgs) <= self.batch_size: self.batch_size = len(imgs) batch_images = np.array(imgs) -- cgit v1.2.3 From 48dbf99e84045ee7af55bc5b1b86492a240e631e Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 04:17:16 +0900 Subject: Allow tracking real-time loss Someone had 6000 images in their dataset, and it was shown as 0, which was confusing. This will allow tracking real time dataset-average loss for registered objects. --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 81132be4..99fd0f8f 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -360,7 +360,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: hypernetwork.step = i + ititial_step - if loss_dict and i % size == 0: + if len(loss_dict) > 0: previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict) scheduler.apply(optimizer, hypernetwork.step) -- cgit v1.2.3 From ce42879438bf2dbd76b5b346be656292e42ffb2b Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sat, 22 Oct 2022 14:53:37 -0500 Subject: fix js func signature and not forget to initialize confirmation var to prevent exception upon cancelling confirmation --- javascript/ui.js | 7 ++++--- modules/ui.py | 4 +++- 2 files changed, 7 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/javascript/ui.js b/javascript/ui.js index 6c99824b..39011079 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -162,12 +162,13 @@ function selected_tab_id() { } -function clear_prompt(_, _prompt_neg, confirmed,_steps) { +function clear_prompt(_, _prompt_neg, confirmed, _token_counter) { +confirmed = false if(confirm("Delete prompt?")) { confirmed = true } else { -return [_, confirmed] +return [_, _prompt_neg, confirmed, _token_counter] } if(selected_tab_id() == "tab_txt2img") { @@ -176,7 +177,7 @@ return [_, confirmed] update_token_counter("txt2img_token_button") } - return [_, _prompt_neg, confirmed,_steps] + return [_, _prompt_neg, confirmed, _token_counter] } diff --git a/modules/ui.py b/modules/ui.py index 25aeba3b..e58f040e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -429,10 +429,12 @@ def create_seed_inputs(): return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox -def clear_prompt(_prompt, _prompt_neg, confirmed, _token_counter): +def clear_prompt(prompt, _prompt_neg, confirmed, _token_counter): """Given confirmation from a user on the client-side, go ahead with clearing prompt""" if confirmed: return ["", "", confirmed, update_token_counter("", 1)] + else: + return [prompt, _prompt_neg, confirmed, _token_counter] def connect_clear_prompt(button, prompt, prompt_neg, _dummy_confirmed, token_counter): -- cgit v1.2.3 From 1b4d04737ac513cbd55958bb60a4f85166f3484b Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 22 Oct 2022 20:13:16 -0300 Subject: Remove unused imports --- modules/api/api.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 5b0c934e..a5136b4b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,11 +1,9 @@ from modules.api.processing import StableDiffusionProcessingAPI from modules.processing import StableDiffusionProcessingTxt2Img, process_images from modules.sd_samplers import all_samplers -from modules.extras import run_pnginfo import modules.shared as shared import uvicorn -from fastapi import Body, APIRouter, HTTPException -from fastapi.responses import JSONResponse +from fastapi import APIRouter, HTTPException from pydantic import BaseModel, Field, Json import json import io @@ -18,7 +16,6 @@ class TextToImageResponse(BaseModel): parameters: Json info: Json - class Api: def __init__(self, app, queue_lock): self.router = APIRouter() -- cgit v1.2.3 From b02926df1393df311db734af149fb9faf4389cbe Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 22 Oct 2022 20:24:04 -0300 Subject: Moved moodels to their own file and extracted base64 conversion to its own function --- modules/api/api.py | 17 ++++++----------- modules/api/models.py | 8 ++++++++ 2 files changed, 14 insertions(+), 11 deletions(-) create mode 100644 modules/api/models.py (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index a5136b4b..c17d7580 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -4,17 +4,17 @@ from modules.sd_samplers import all_samplers import modules.shared as shared import uvicorn from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, Field, Json import json import io import base64 +from modules.api.models import * sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) -class TextToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: Json - info: Json +def img_to_base64(img): + buffer = io.BytesIO() + img.save(buffer, format="png") + return base64.b64encode(buffer.getvalue()) class Api: def __init__(self, app, queue_lock): @@ -41,15 +41,10 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = [] - for i in processed.images: - buffer = io.BytesIO() - i.save(buffer, format="png") - b64images.append(base64.b64encode(buffer.getvalue())) + b64images = list(map(img_to_base64, processed.images)) return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) - def img2imgapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py new file mode 100644 index 00000000..a7d247d8 --- /dev/null +++ b/modules/api/models.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel, Field, Json + +class TextToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: Json + info: Json + + \ No newline at end of file -- cgit v1.2.3 From 28e26c2bef217ae82eb9e980cceb3f67ef22e109 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 22 Oct 2022 23:13:32 -0300 Subject: Add "extra" single image operation - Separate extra modes into 3 endpoints so the user ddoesn't ahve to handle so many unused parameters. - Add response model for codumentation --- modules/api/api.py | 43 ++++++++++++++++++++++++++++++++++++++----- modules/api/models.py | 26 +++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index c17d7580..3b804373 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -8,20 +8,42 @@ import json import io import base64 from modules.api.models import * +from PIL import Image +from modules.extras import run_extras + +def upscaler_to_index(name: str): + try: + return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) + except: + raise HTTPException(status_code=400, detail="Upscaler not found") sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) -def img_to_base64(img): +def img_to_base64(img: str): buffer = io.BytesIO() img.save(buffer, format="png") return base64.b64encode(buffer.getvalue()) +def base64_to_bytes(base64Img: str): + if "," in base64Img: + base64Img = base64Img.split(",")[1] + return io.BytesIO(base64.b64decode(base64Img)) + +def base64_to_images(base64Imgs: list[str]): + imgs = [] + for img in base64Imgs: + img = Image.open(base64_to_bytes(img)) + imgs.append(img) + return imgs + + class Api: def __init__(self, app, queue_lock): self.router = APIRouter() self.app = app self.queue_lock = queue_lock - self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) + self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -45,12 +67,23 @@ class Api: return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) - def img2imgapi(self): raise NotImplementedError - def extrasapi(self): - raise NotImplementedError + def extras_single_image_api(self, req: ExtrasSingleImageRequest): + upscaler1Index = upscaler_to_index(req.upscaler_1) + upscaler2Index = upscaler_to_index(req.upscaler_2) + + reqDict = vars(req) + reqDict.pop('upscaler_1') + reqDict.pop('upscaler_2') + + reqDict['image'] = base64_to_images([reqDict['image']])[0] + + with self.queue_lock: + result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="") + + return ExtrasSingleImageResponse(image="data:image/png;base64,"+img_to_base64(result[0]), html_info_x=result[1], html_info=result[2]) def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index a7d247d8..dcf1ab54 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,8 +1,32 @@ from pydantic import BaseModel, Field, Json +from typing_extensions import Literal +from modules.shared import sd_upscalers class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: Json info: Json - \ No newline at end of file +class ExtrasBaseRequest(BaseModel): + resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.") + show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?") + gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") + codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") + codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") + upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.") + upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?") + upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") + upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") + extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") + +class ExtraBaseResponse(BaseModel): + html_info_x: str + html_info: str + +class ExtrasSingleImageRequest(ExtrasBaseRequest): + image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") + +class ExtrasSingleImageResponse(ExtraBaseResponse): + image: str = Field(default=None, title="Image", description="The generated image in base64 format.") \ No newline at end of file -- cgit v1.2.3 From 1fbfc052eb529d8cf8ce5baf578bcf93d0280c29 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 23 Oct 2022 05:43:34 +0100 Subject: Update hypernetwork.py --- modules/hypernetworks/hypernetwork.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 99fd0f8f..98a7b62e 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -288,10 +288,13 @@ def statistics(data): def report_statistics(loss_info:dict): keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x])) for key in keys: - info, recent = statistics(loss_info[key]) - print("Loss statistics for file " + key) - print(info) - print(recent) + try: + print("Loss statistics for file " + key) + info, recent = statistics(loss_info[key]) + print(info) + print(recent) + except Exception as e: + print(e) -- cgit v1.2.3 From a7c213d0f5ebb10722629b8490a5863f9ce6c4fa Mon Sep 17 00:00:00 2001 From: Stephen Date: Fri, 21 Oct 2022 19:27:40 -0400 Subject: [API][Feature] - Add img2img API endpoint --- modules/api/api.py | 58 +++++++++++++++++++++++++++++++++++++++++++---- modules/api/processing.py | 11 +++++++-- modules/processing.py | 2 +- 3 files changed, 63 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 5b0c934e..a04f2428 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,5 +1,5 @@ -from modules.api.processing import StableDiffusionProcessingAPI -from modules.processing import StableDiffusionProcessingTxt2Img, process_images +from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI +from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.sd_samplers import all_samplers from modules.extras import run_pnginfo import modules.shared as shared @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, Json import json import io import base64 +from PIL import Image sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) @@ -18,6 +19,11 @@ class TextToImageResponse(BaseModel): parameters: Json info: Json +class ImageToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: Json + info: Json + class Api: def __init__(self, app, queue_lock): @@ -25,8 +31,9 @@ class Api: self.app = app self.queue_lock = queue_lock self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) - def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) if sampler_index is None: @@ -54,8 +61,49 @@ class Api: - def img2imgapi(self): - raise NotImplementedError + def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): + sampler_index = sampler_to_index(img2imgreq.sampler_index) + + if sampler_index is None: + raise HTTPException(status_code=404, detail="Sampler not found") + + + init_images = img2imgreq.init_images + if init_images is None: + raise HTTPException(status_code=404, detail="Init image not found") + + + populate = img2imgreq.copy(update={ # Override __init__ params + "sd_model": shared.sd_model, + "sampler_index": sampler_index[0], + "do_not_save_samples": True, + "do_not_save_grid": True + } + ) + p = StableDiffusionProcessingImg2Img(**vars(populate)) + + imgs = [] + for img in init_images: + # if has a comma, deal with prefix + if "," in img: + img = img.split(",")[1] + # convert base64 to PIL image + img = base64.b64decode(img) + img = Image.open(io.BytesIO(img)) + imgs = [img] * p.batch_size + + p.init_images = imgs + # Override object param + with self.queue_lock: + processed = process_images(p) + + b64images = [] + for i in processed.images: + buffer = io.BytesIO() + i.save(buffer, format="png") + b64images.append(base64.b64encode(buffer.getvalue())) + + return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info)) def extrasapi(self): raise NotImplementedError diff --git a/modules/api/processing.py b/modules/api/processing.py index 4c541241..9f1d65c0 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -1,7 +1,8 @@ +from array import array from inflection import underscore from typing import Any, Dict, Optional from pydantic import BaseModel, Field, create_model -from modules.processing import StableDiffusionProcessingTxt2Img +from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img import inspect @@ -92,8 +93,14 @@ class PydanticModelGenerator: DynamicModel.__config__.allow_mutation = True return DynamicModel -StableDiffusionProcessingAPI = PydanticModelGenerator( +StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, [{"key": "sampler_index", "type": str, "default": "Euler"}] +).generate_model() + +StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingImg2Img", + StableDiffusionProcessingImg2Img, + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}] ).generate_model() \ No newline at end of file diff --git a/modules/processing.py b/modules/processing.py index b1877b80..1557ed8c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -623,7 +623,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): sampler = None - def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs): + def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: str=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs): super().__init__(**kwargs) self.init_images = init_images -- cgit v1.2.3 From 9e1a8b7734a2881451a2efbf80def011ea41ba49 Mon Sep 17 00:00:00 2001 From: Stephen Date: Sat, 22 Oct 2022 15:42:00 -0400 Subject: non-implemented mask with any type --- modules/api/api.py | 4 ++++ modules/api/processing.py | 2 +- modules/processing.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index a04f2428..3df6ff96 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -72,6 +72,10 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") + mask = img2imgreq.mask + if mask: + raise HTTPException(status_code=400, detail="Mask not supported yet") + populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, diff --git a/modules/api/processing.py b/modules/api/processing.py index 9f1d65c0..f551fa35 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -102,5 +102,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] ).generate_model() \ No newline at end of file diff --git a/modules/processing.py b/modules/processing.py index 1557ed8c..ff83023c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -623,7 +623,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): sampler = None - def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: str=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs): + def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs): super().__init__(**kwargs) self.init_images = init_images -- cgit v1.2.3 From 5dc0739ecdc1ade8fcf4eb77f2a503ef12489f32 Mon Sep 17 00:00:00 2001 From: Stephen Date: Sat, 22 Oct 2022 17:10:28 -0400 Subject: working mask --- modules/api/api.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 3df6ff96..3caa83a4 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -33,6 +33,14 @@ class Api: self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) + def __base64_to_image(self, base64_string): + # if has a comma, deal with prefix + if "," in base64_string: + base64_string = base64_string.split(",")[1] + imgdata = base64.b64decode(base64_string) + # convert base64 to PIL image + return Image.open(io.BytesIO(imgdata)) + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -74,26 +82,22 @@ class Api: mask = img2imgreq.mask if mask: - raise HTTPException(status_code=400, detail="Mask not supported yet") + mask = self.__base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, "sampler_index": sampler_index[0], "do_not_save_samples": True, - "do_not_save_grid": True + "do_not_save_grid": True, + "mask": mask } ) p = StableDiffusionProcessingImg2Img(**vars(populate)) imgs = [] for img in init_images: - # if has a comma, deal with prefix - if "," in img: - img = img.split(",")[1] - # convert base64 to PIL image - img = base64.b64decode(img) - img = Image.open(io.BytesIO(img)) + img = self.__base64_to_image(img) imgs = [img] * p.batch_size p.init_images = imgs -- cgit v1.2.3 From 1be5933ba21a3badec42b7b2753d626f849b609d Mon Sep 17 00:00:00 2001 From: captin411 Date: Sun, 23 Oct 2022 04:11:07 -0700 Subject: auto cropping now works with non square crops --- modules/textual_inversion/autocrop.py | 509 ++++++++++++++++++---------------- 1 file changed, 269 insertions(+), 240 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 5a551c25..b2f9241c 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -1,241 +1,270 @@ -import cv2 -from collections import defaultdict -from math import log, sqrt -import numpy as np -from PIL import Image, ImageDraw - -GREEN = "#0F0" -BLUE = "#00F" -RED = "#F00" - - -def crop_image(im, settings): - """ Intelligently crop an image to the subject matter """ - if im.height > im.width: - im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width)) - elif im.width > im.height: - im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height)) - else: - im = im.resize((settings.crop_width, settings.crop_height)) - - if im.height == im.width: - return im - - focus = focal_point(im, settings) - - # take the focal point and turn it into crop coordinates that try to center over the focal - # point but then get adjusted back into the frame - y_half = int(settings.crop_height / 2) - x_half = int(settings.crop_width / 2) - - x1 = focus.x - x_half - if x1 < 0: - x1 = 0 - elif x1 + settings.crop_width > im.width: - x1 = im.width - settings.crop_width - - y1 = focus.y - y_half - if y1 < 0: - y1 = 0 - elif y1 + settings.crop_height > im.height: - y1 = im.height - settings.crop_height - - x2 = x1 + settings.crop_width - y2 = y1 + settings.crop_height - - crop = [x1, y1, x2, y2] - - if settings.annotate_image: - d = ImageDraw.Draw(im) - rect = list(crop) - rect[2] -= 1 - rect[3] -= 1 - d.rectangle(rect, outline=GREEN) - if settings.destop_view_image: - im.show() - - return im.crop(tuple(crop)) - -def focal_point(im, settings): - corner_points = image_corner_points(im, settings) - entropy_points = image_entropy_points(im, settings) - face_points = image_face_points(im, settings) - - total_points = len(corner_points) + len(entropy_points) + len(face_points) - - corner_weight = settings.corner_points_weight - entropy_weight = settings.entropy_points_weight - face_weight = settings.face_points_weight - - weight_pref_total = corner_weight + entropy_weight + face_weight - - # weight things - pois = [] - if weight_pref_total == 0 or total_points == 0: - return pois - - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] - ) - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] - ) - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] - ) - - average_point = poi_average(pois, settings) - - if settings.annotate_image: - d = ImageDraw.Draw(im) - for f in face_points: - d.rectangle(f.bounding(f.size), outline=RED) - for f in entropy_points: - d.rectangle(f.bounding(30), outline=BLUE) - for poi in pois: - w = max(4, 4 * 0.5 * sqrt(poi.weight)) - d.ellipse(poi.bounding(w), fill=BLUE) - d.ellipse(average_point.bounding(25), outline=GREEN) - - return average_point - - -def image_face_points(im, settings): - np_im = np.array(im) - gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - - tries = [ - [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] - ] - - for t in tries: - # print(t[0]) - classifier = cv2.CascadeClassifier(t[0]) - minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side - try: - faces = classifier.detectMultiScale(gray, scaleFactor=1.1, - minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - except: - continue - - if len(faces) > 0: - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] - return [] - - -def image_corner_points(im, settings): - grayscale = im.convert("L") - - # naive attempt at preventing focal points from collecting at watermarks near the bottom - gd = ImageDraw.Draw(grayscale) - gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") - - np_im = np.array(grayscale) - - points = cv2.goodFeaturesToTrack( - np_im, - maxCorners=100, - qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.07, - useHarrisDetector=False, - ) - - if points is None: - return [] - - focal_points = [] - for point in points: - x, y = point.ravel() - focal_points.append(PointOfInterest(x, y, size=4)) - - return focal_points - - -def image_entropy_points(im, settings): - landscape = im.height < im.width - portrait = im.height > im.width - if landscape: - move_idx = [0, 2] - move_max = im.size[0] - elif portrait: - move_idx = [1, 3] - move_max = im.size[1] - else: - return [] - - e_max = 0 - crop_current = [0, 0, settings.crop_width, settings.crop_height] - crop_best = crop_current - while crop_current[move_idx[1]] < move_max: - crop = im.crop(tuple(crop_current)) - e = image_entropy(crop) - - if (e > e_max): - e_max = e - crop_best = list(crop_current) - - crop_current[move_idx[0]] += 4 - crop_current[move_idx[1]] += 4 - - x_mid = int(crop_best[0] + settings.crop_width/2) - y_mid = int(crop_best[1] + settings.crop_height/2) - - return [PointOfInterest(x_mid, y_mid, size=25)] - - -def image_entropy(im): - # greyscale image entropy - # band = np.asarray(im.convert("L")) - band = np.asarray(im.convert("1"), dtype=np.uint8) - hist, _ = np.histogram(band, bins=range(0, 256)) - hist = hist[hist > 0] - return -np.log2(hist / hist.sum()).sum() - - -def poi_average(pois, settings): - weight = 0.0 - x = 0.0 - y = 0.0 - for poi in pois: - weight += poi.weight - x += poi.x * poi.weight - y += poi.y * poi.weight - avg_x = round(x / weight) - avg_y = round(y / weight) - - return PointOfInterest(avg_x, avg_y) - - -class PointOfInterest: - def __init__(self, x, y, weight=1.0, size=10): - self.x = x - self.y = y - self.weight = weight - self.size = size - - def bounding(self, size): - return [ - self.x - size//2, - self.y - size//2, - self.x + size//2, - self.y + size//2 - ] - - -class Settings: - def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): - self.crop_width = crop_width - self.crop_height = crop_height - self.corner_points_weight = corner_points_weight - self.entropy_points_weight = entropy_points_weight - self.face_points_weight = entropy_points_weight - self.annotate_image = annotate_image +import cv2 +from collections import defaultdict +from math import log, sqrt +import numpy as np +from PIL import Image, ImageDraw + +GREEN = "#0F0" +BLUE = "#00F" +RED = "#F00" + + +def crop_image(im, settings): + """ Intelligently crop an image to the subject matter """ + + scale_by = 1 + if is_landscape(im.width, im.height): + scale_by = settings.crop_height / im.height + elif is_portrait(im.width, im.height): + scale_by = settings.crop_width / im.width + elif is_square(im.width, im.height): + if is_square(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_landscape(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_portrait(settings.crop_width, settings.crop_height): + scale_by = settings.crop_height / im.height + + im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) + + if im.width == settings.crop_width and im.height == settings.crop_height: + if settings.annotate_image: + d = ImageDraw.Draw(im) + rect = [0, 0, im.width, im.height] + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + if settings.destop_view_image: + im.show() + return im + + focus = focal_point(im, settings) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(settings.crop_height / 2) + x_half = int(settings.crop_width / 2) + + x1 = focus.x - x_half + if x1 < 0: + x1 = 0 + elif x1 + settings.crop_width > im.width: + x1 = im.width - settings.crop_width + + y1 = focus.y - y_half + if y1 < 0: + y1 = 0 + elif y1 + settings.crop_height > im.height: + y1 = im.height - settings.crop_height + + x2 = x1 + settings.crop_width + y2 = y1 + settings.crop_height + + crop = [x1, y1, x2, y2] + + if settings.annotate_image: + d = ImageDraw.Draw(im) + rect = list(crop) + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + if settings.destop_view_image: + im.show() + + return im.crop(tuple(crop)) + +def focal_point(im, settings): + corner_points = image_corner_points(im, settings) + entropy_points = image_entropy_points(im, settings) + face_points = image_face_points(im, settings) + + total_points = len(corner_points) + len(entropy_points) + len(face_points) + + corner_weight = settings.corner_points_weight + entropy_weight = settings.entropy_points_weight + face_weight = settings.face_points_weight + + weight_pref_total = corner_weight + entropy_weight + face_weight + + # weight things + pois = [] + if weight_pref_total == 0 or total_points == 0: + return pois + + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] + ) + + average_point = poi_average(pois, settings) + + if settings.annotate_image: + d = ImageDraw.Draw(im) + for f in face_points: + d.rectangle(f.bounding(f.size), outline=RED) + for f in entropy_points: + d.rectangle(f.bounding(30), outline=BLUE) + for poi in pois: + w = max(4, 4 * 0.5 * sqrt(poi.weight)) + d.ellipse(poi.bounding(w), fill=BLUE) + d.ellipse(average_point.bounding(25), outline=GREEN) + + return average_point + + +def image_face_points(im, settings): + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + + tries = [ + [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] + ] + + for t in tries: + # print(t[0]) + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + except: + continue + + if len(faces) > 0: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] + return [] + + +def image_corner_points(im, settings): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=100, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.07, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4)) + + return focal_points + + +def image_entropy_points(im, settings): + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] + else: + return [] + + e_max = 0 + crop_current = [0, 0, settings.crop_width, settings.crop_height] + crop_best = crop_current + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e > e_max): + e_max = e + crop_best = list(crop_current) + + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + settings.crop_width/2) + y_mid = int(crop_best[1] + settings.crop_height/2) + + return [PointOfInterest(x_mid, y_mid, size=25)] + + +def image_entropy(im): + # greyscale image entropy + # band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1"), dtype=np.uint8) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + + +def poi_average(pois, settings): + weight = 0.0 + x = 0.0 + y = 0.0 + for poi in pois: + weight += poi.weight + x += poi.x * poi.weight + y += poi.y * poi.weight + avg_x = round(x / weight) + avg_y = round(y / weight) + + return PointOfInterest(avg_x, avg_y) + + +def is_landscape(w, h): + return w > h + + +def is_portrait(w, h): + return h > w + + +def is_square(w, h): + return w == h + + +class PointOfInterest: + def __init__(self, x, y, weight=1.0, size=10): + self.x = x + self.y = y + self.weight = weight + self.size = size + + def bounding(self, size): + return [ + self.x - size//2, + self.y - size//2, + self.x + size//2, + self.y + size//2 + ] + + +class Settings: + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): + self.crop_width = crop_width + self.crop_height = crop_height + self.corner_points_weight = corner_points_weight + self.entropy_points_weight = entropy_points_weight + self.face_points_weight = entropy_points_weight + self.annotate_image = annotate_image self.destop_view_image = False \ No newline at end of file -- cgit v1.2.3 From 0523704dade0508bf3ae0c8cb799b1ae332d449b Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 12:27:50 -0300 Subject: Update run_extras to use the temp filename In batch mode run_extras tries to preserve the original file name of the images. The problem is that this makes no sense since the user only gets a list of images in the UI, trying to manually save them shows that this images have random temp names. Also, trying to keep "orig_name" in the API is a hassle that adds complexity to the consuming UI since the client has to use (or emulate) an input (type=file) element in a form. Using the normal file name not only doesn't change the output and functionality in the original UI but also helps keep the API simple. --- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 22c5a1c1..29ac312e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -33,7 +33,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ for img in image_folder: image = Image.open(img) imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) + imageNameArr.append(os.path.splitext(img.name)[0]) elif extras_mode == 2: assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' -- cgit v1.2.3 From 4ff852ffb50859f2eae75375cab94dd790a46886 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 13:07:59 -0300 Subject: Add batch processing "extras" endpoint --- modules/api/api.py | 25 +++++++++++++++++++++++-- modules/api/models.py | 15 ++++++++++++++- 2 files changed, 37 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 3b804373..528134a8 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -10,6 +10,7 @@ import base64 from modules.api.models import * from PIL import Image from modules.extras import run_extras +from gradio import processing_utils def upscaler_to_index(name: str): try: @@ -44,6 +45,7 @@ class Api: self.queue_lock = queue_lock self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) + self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -78,12 +80,31 @@ class Api: reqDict.pop('upscaler_1') reqDict.pop('upscaler_2') - reqDict['image'] = base64_to_images([reqDict['image']])[0] + reqDict['image'] = processing_utils.decode_base64_to_file(reqDict['image']) with self.queue_lock: result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="") - return ExtrasSingleImageResponse(image="data:image/png;base64,"+img_to_base64(result[0]), html_info_x=result[1], html_info=result[2]) + return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0]), html_info_x=result[1], html_info=result[2]) + + def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): + upscaler1Index = upscaler_to_index(req.upscaler_1) + upscaler2Index = upscaler_to_index(req.upscaler_2) + + reqDict = vars(req) + reqDict.pop('upscaler_1') + reqDict.pop('upscaler_2') + + reqDict['image_folder'] = list(map(processing_utils.decode_base64_to_file, reqDict['imageList'])) + reqDict.pop('imageList') + + with self.queue_lock: + result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=1, image="", input_dir="", output_dir="") + + return ExtrasBatchImagesResponse(images=list(map(processing_utils.encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) + + def extras_folder_processing_api(self): + raise NotImplementedError def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index dcf1ab54..bbd0ef53 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -29,4 +29,17 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") class ExtrasSingleImageResponse(ExtraBaseResponse): - image: str = Field(default=None, title="Image", description="The generated image in base64 format.") \ No newline at end of file + image: str = Field(default=None, title="Image", description="The generated image in base64 format.") + +class SerializableImage(BaseModel): + path: str = Field(title="Path", description="The image's path ()") + +class ImageItem(BaseModel): + data: str = Field(title="image data") + name: str = Field(title="filename") + +class ExtrasBatchImagesRequest(ExtrasBaseRequest): + imageList: list[str] = Field(title="Images", description="List of images to work on. Must be Base64 strings") + +class ExtrasBatchImagesResponse(ExtraBaseResponse): + images: list[str] = Field(title="Images", description="The generated images in base64 format.") \ No newline at end of file -- cgit v1.2.3 From e0ca4dfbc10e0af8dfc4185e5e758f33fd2f0d81 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 15:13:37 -0300 Subject: Update endpoints to use gradio's own utils functions --- modules/api/api.py | 75 +++++++++++++++++++++++++-------------------------- modules/api/models.py | 4 +-- 2 files changed, 38 insertions(+), 41 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 3f490ce2..3acb1f36 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -20,27 +20,27 @@ def upscaler_to_index(name: str): sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) -def img_to_base64(img: str): - buffer = io.BytesIO() - img.save(buffer, format="png") - return base64.b64encode(buffer.getvalue()) - -def base64_to_bytes(base64Img: str): - if "," in base64Img: - base64Img = base64Img.split(",")[1] - return io.BytesIO(base64.b64decode(base64Img)) - -def base64_to_images(base64Imgs: list[str]): - imgs = [] - for img in base64Imgs: - img = Image.open(base64_to_bytes(img)) - imgs.append(img) - return imgs +# def img_to_base64(img: str): +# buffer = io.BytesIO() +# img.save(buffer, format="png") +# return base64.b64encode(buffer.getvalue()) + +# def base64_to_bytes(base64Img: str): +# if "," in base64Img: +# base64Img = base64Img.split(",")[1] +# return io.BytesIO(base64.b64decode(base64Img)) + +# def base64_to_images(base64Imgs: list[str]): +# imgs = [] +# for img in base64Imgs: +# img = Image.open(base64_to_bytes(img)) +# imgs.append(img) +# return imgs class ImageToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: Json - info: Json + parameters: dict + info: str class Api: @@ -49,17 +49,17 @@ class Api: self.app = app self.queue_lock = queue_lock self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) - self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) - def __base64_to_image(self, base64_string): - # if has a comma, deal with prefix - if "," in base64_string: - base64_string = base64_string.split(",")[1] - imgdata = base64.b64decode(base64_string) - # convert base64 to PIL image - return Image.open(io.BytesIO(imgdata)) + # def __base64_to_image(self, base64_string): + # # if has a comma, deal with prefix + # if "," in base64_string: + # base64_string = base64_string.split(",")[1] + # imgdata = base64.b64decode(base64_string) + # # convert base64 to PIL image + # return Image.open(io.BytesIO(imgdata)) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -79,11 +79,9 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = list(map(img_to_base64, processed.images)) - - return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) - + b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) + return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.info) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): sampler_index = sampler_to_index(img2imgreq.sampler_index) @@ -98,7 +96,7 @@ class Api: mask = img2imgreq.mask if mask: - mask = self.__base64_to_image(mask) + mask = processing_utils.decode_base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params @@ -113,7 +111,7 @@ class Api: imgs = [] for img in init_images: - img = self.__base64_to_image(img) + img = processing_utils.decode_base64_to_image(img) imgs = [img] * p.batch_size p.init_images = imgs @@ -121,13 +119,12 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = [] - for i in processed.images: - buffer = io.BytesIO() - i.save(buffer, format="png") - b64images.append(base64.b64encode(buffer.getvalue())) - - return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info)) + b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) + # for i in processed.images: + # buffer = io.BytesIO() + # i.save(buffer, format="png") + # b64images.append(base64.b64encode(buffer.getvalue())) + return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info) def extras_single_image_api(self, req: ExtrasSingleImageRequest): upscaler1Index = upscaler_to_index(req.upscaler_1) diff --git a/modules/api/models.py b/modules/api/models.py index bbd0ef53..209f8af5 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -4,8 +4,8 @@ from modules.shared import sd_upscalers class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: Json - info: Json + parameters: str + info: str class ExtrasBaseRequest(BaseModel): resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.") -- cgit v1.2.3 From 866b36d705a338d299aba385788729d60f7d48c8 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 15:35:49 -0300 Subject: Move processing's models into models.py It didn't make sense to have two differente files for the same and "models" is a more descriptive name. --- modules/api/api.py | 57 ++++------------------- modules/api/models.py | 112 +++++++++++++++++++++++++++++++++++++++++++++- modules/api/processing.py | 106 ------------------------------------------- 3 files changed, 119 insertions(+), 156 deletions(-) delete mode 100644 modules/api/processing.py (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 3acb1f36..20e85e82 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,16 +1,11 @@ -from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI -from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers -import modules.shared as shared import uvicorn +from gradio import processing_utils from fastapi import APIRouter, HTTPException -import json -import io -import base64 +import modules.shared as shared from modules.api.models import * -from PIL import Image +from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +from modules.sd_samplers import all_samplers from modules.extras import run_extras -from gradio import processing_utils def upscaler_to_index(name: str): try: @@ -20,29 +15,6 @@ def upscaler_to_index(name: str): sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) -# def img_to_base64(img: str): -# buffer = io.BytesIO() -# img.save(buffer, format="png") -# return base64.b64encode(buffer.getvalue()) - -# def base64_to_bytes(base64Img: str): -# if "," in base64Img: -# base64Img = base64Img.split(",")[1] -# return io.BytesIO(base64.b64decode(base64Img)) - -# def base64_to_images(base64Imgs: list[str]): -# imgs = [] -# for img in base64Imgs: -# img = Image.open(base64_to_bytes(img)) -# imgs.append(img) -# return imgs - -class ImageToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: dict - info: str - - class Api: def __init__(self, app, queue_lock): self.router = APIRouter() @@ -51,15 +23,7 @@ class Api: self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) - self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) - - # def __base64_to_image(self, base64_string): - # # if has a comma, deal with prefix - # if "," in base64_string: - # base64_string = base64_string.split(",")[1] - # imgdata = base64.b64decode(base64_string) - # # convert base64 to PIL image - # return Image.open(io.BytesIO(imgdata)) + self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -81,7 +45,7 @@ class Api: b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) - return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.info) + return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.info) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): sampler_index = sampler_to_index(img2imgreq.sampler_index) @@ -120,10 +84,7 @@ class Api: processed = process_images(p) b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) - # for i in processed.images: - # buffer = io.BytesIO() - # i.save(buffer, format="png") - # b64images.append(base64.b64encode(buffer.getvalue())) + return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info) def extras_single_image_api(self, req: ExtrasSingleImageRequest): @@ -134,12 +95,12 @@ class Api: reqDict.pop('upscaler_1') reqDict.pop('upscaler_2') - reqDict['image'] = processing_utils.decode_base64_to_file(reqDict['image']) + reqDict['image'] = processing_utils.decode_base64_to_image(reqDict['image']) with self.queue_lock: result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="") - return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0]), html_info_x=result[1], html_info=result[2]) + return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2]) def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): upscaler1Index = upscaler_to_index(req.upscaler_1) diff --git a/modules/api/models.py b/modules/api/models.py index 209f8af5..362e6277 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,10 +1,118 @@ -from pydantic import BaseModel, Field, Json +import inspect +from pydantic import BaseModel, Field, Json, create_model +from typing import Any, Optional from typing_extensions import Literal +from inflection import underscore +from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img from modules.shared import sd_upscalers +API_NOT_ALLOWED = [ + "self", + "kwargs", + "sd_model", + "outpath_samples", + "outpath_grids", + "sampler_index", + "do_not_save_samples", + "do_not_save_grid", + "extra_generation_params", + "overlay_images", + "do_not_reload_embeddings", + "seed_enable_extras", + "prompt_for_display", + "sampler_noise_scheduler_override", + "ddim_discretize" +] + +class ModelDef(BaseModel): + """Assistance Class for Pydantic Dynamic Model Generation""" + + field: str + field_alias: str + field_type: Any + field_value: Any + + +class PydanticModelGenerator: + """ + Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: + source_data is a snapshot of the default values produced by the class + params are the names of the actual keys required by __init__ + """ + + def __init__( + self, + model_name: str = None, + class_instance = None, + additional_fields = None, + ): + def field_type_generator(k, v): + # field_type = str if not overrides.get(k) else overrides[k]["type"] + # print(k, v.annotation, v.default) + field_type = v.annotation + + return Optional[field_type] + + def merge_class_params(class_): + all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) + parameters = {} + for classes in all_classes: + parameters = {**parameters, **inspect.signature(classes.__init__).parameters} + return parameters + + + self._model_name = model_name + self._class_data = merge_class_params(class_instance) + self._model_def = [ + ModelDef( + field=underscore(k), + field_alias=k, + field_type=field_type_generator(k, v), + field_value=v.default + ) + for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED + ] + + for fields in additional_fields: + self._model_def.append(ModelDef( + field=underscore(fields["key"]), + field_alias=fields["key"], + field_type=fields["type"], + field_value=fields["default"])) + + def generate_model(self): + """ + Creates a pydantic BaseModel + from the json and overrides provided at initialization + """ + fields = { + d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def + } + DynamicModel = create_model(self._model_name, **fields) + DynamicModel.__config__.allow_population_by_field_name = True + DynamicModel.__config__.allow_mutation = True + return DynamicModel + +StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingTxt2Img", + StableDiffusionProcessingTxt2Img, + [{"key": "sampler_index", "type": str, "default": "Euler"}] +).generate_model() + +StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingImg2Img", + StableDiffusionProcessingImg2Img, + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] +).generate_model() + class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: str + parameters: dict + info: str + +class ImageToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: dict info: str class ExtrasBaseRequest(BaseModel): diff --git a/modules/api/processing.py b/modules/api/processing.py deleted file mode 100644 index f551fa35..00000000 --- a/modules/api/processing.py +++ /dev/null @@ -1,106 +0,0 @@ -from array import array -from inflection import underscore -from typing import Any, Dict, Optional -from pydantic import BaseModel, Field, create_model -from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img -import inspect - - -API_NOT_ALLOWED = [ - "self", - "kwargs", - "sd_model", - "outpath_samples", - "outpath_grids", - "sampler_index", - "do_not_save_samples", - "do_not_save_grid", - "extra_generation_params", - "overlay_images", - "do_not_reload_embeddings", - "seed_enable_extras", - "prompt_for_display", - "sampler_noise_scheduler_override", - "ddim_discretize" -] - -class ModelDef(BaseModel): - """Assistance Class for Pydantic Dynamic Model Generation""" - - field: str - field_alias: str - field_type: Any - field_value: Any - - -class PydanticModelGenerator: - """ - Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: - source_data is a snapshot of the default values produced by the class - params are the names of the actual keys required by __init__ - """ - - def __init__( - self, - model_name: str = None, - class_instance = None, - additional_fields = None, - ): - def field_type_generator(k, v): - # field_type = str if not overrides.get(k) else overrides[k]["type"] - # print(k, v.annotation, v.default) - field_type = v.annotation - - return Optional[field_type] - - def merge_class_params(class_): - all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) - parameters = {} - for classes in all_classes: - parameters = {**parameters, **inspect.signature(classes.__init__).parameters} - return parameters - - - self._model_name = model_name - self._class_data = merge_class_params(class_instance) - self._model_def = [ - ModelDef( - field=underscore(k), - field_alias=k, - field_type=field_type_generator(k, v), - field_value=v.default - ) - for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED - ] - - for fields in additional_fields: - self._model_def.append(ModelDef( - field=underscore(fields["key"]), - field_alias=fields["key"], - field_type=fields["type"], - field_value=fields["default"])) - - def generate_model(self): - """ - Creates a pydantic BaseModel - from the json and overrides provided at initialization - """ - fields = { - d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def - } - DynamicModel = create_model(self._model_name, **fields) - DynamicModel.__config__.allow_population_by_field_name = True - DynamicModel.__config__.allow_mutation = True - return DynamicModel - -StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( - "StableDiffusionProcessingTxt2Img", - StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}] -).generate_model() - -StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( - "StableDiffusionProcessingImg2Img", - StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] -).generate_model() \ No newline at end of file -- cgit v1.2.3 From 1e625624ba6ab3dfc167f0a5226780bb9b50fb58 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 16:01:16 -0300 Subject: Add folder processing endpoint Also minor refactor --- modules/api/api.py | 56 +++++++++++++++++++++++++++------------------------ modules/api/models.py | 6 +++++- 2 files changed, 35 insertions(+), 27 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 20e85e82..7b4fbe29 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,5 +1,5 @@ import uvicorn -from gradio import processing_utils +from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, HTTPException import modules.shared as shared from modules.api.models import * @@ -11,10 +11,18 @@ def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) except: - raise HTTPException(status_code=400, detail="Upscaler not found") + raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) +def setUpscalers(req: dict): + reqDict = vars(req) + reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) + reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2) + reqDict.pop('upscaler_1') + reqDict.pop('upscaler_2') + return reqDict + class Api: def __init__(self, app, queue_lock): self.router = APIRouter() @@ -24,6 +32,7 @@ class Api: self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) + self.app.add_api_route("/sdapi/v1/extra-folder-images", self.extras_folder_processing_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -43,7 +52,7 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) + b64images = list(map(encode_pil_to_base64, processed.images)) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.info) @@ -60,7 +69,7 @@ class Api: mask = img2imgreq.mask if mask: - mask = processing_utils.decode_base64_to_image(mask) + mask = decode_base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params @@ -75,7 +84,7 @@ class Api: imgs = [] for img in init_images: - img = processing_utils.decode_base64_to_image(img) + img = decode_base64_to_image(img) imgs = [img] * p.batch_size p.init_images = imgs @@ -83,43 +92,38 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) + b64images = list(map(encode_pil_to_base64, processed.images)) return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info) def extras_single_image_api(self, req: ExtrasSingleImageRequest): - upscaler1Index = upscaler_to_index(req.upscaler_1) - upscaler2Index = upscaler_to_index(req.upscaler_2) - - reqDict = vars(req) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') + reqDict = setUpscalers(req) - reqDict['image'] = processing_utils.decode_base64_to_image(reqDict['image']) + reqDict['image'] = decode_base64_to_image(reqDict['image']) with self.queue_lock: - result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="") + result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict) - return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2]) + return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2]) def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): - upscaler1Index = upscaler_to_index(req.upscaler_1) - upscaler2Index = upscaler_to_index(req.upscaler_2) + reqDict = setUpscalers(req) - reqDict = vars(req) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') - - reqDict['image_folder'] = list(map(processing_utils.decode_base64_to_file, reqDict['imageList'])) + reqDict['image_folder'] = list(map(decode_base64_to_file, reqDict['imageList'])) reqDict.pop('imageList') with self.queue_lock: - result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=1, image="", input_dir="", output_dir="") + result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict) - return ExtrasBatchImagesResponse(images=list(map(processing_utils.encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) + return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) - def extras_folder_processing_api(self): - raise NotImplementedError + def extras_folder_processing_api(self, req:ExtrasFoldersRequest): + reqDict = setUpscalers(req) + + with self.queue_lock: + result = run_extras(extras_mode=2, image=None, image_folder=None, **reqDict) + + return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index 362e6277..6f096807 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -150,4 +150,8 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest): imageList: list[str] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ExtrasBatchImagesResponse(ExtraBaseResponse): - images: list[str] = Field(title="Images", description="The generated images in base64 format.") \ No newline at end of file + images: list[str] = Field(title="Images", description="The generated images in base64 format.") + +class ExtrasFoldersRequest(ExtrasBaseRequest): + input_dir: str = Field(title="Input directory", description="Directory path from where to take the images") + output_dir: str = Field(title="Output directory", description="Directory path to put the processsed images into") -- cgit v1.2.3 From 90f02c75220d187e075203a4e3b450bfba392c4d Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 16:03:30 -0300 Subject: Remove unused field and class --- modules/api/api.py | 6 +++--- modules/api/models.py | 6 +----- 2 files changed, 4 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 7b4fbe29..799e3701 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -104,7 +104,7 @@ class Api: with self.queue_lock: result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict) - return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2]) + return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): reqDict = setUpscalers(req) @@ -115,7 +115,7 @@ class Api: with self.queue_lock: result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict) - return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) + return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) def extras_folder_processing_api(self, req:ExtrasFoldersRequest): reqDict = setUpscalers(req) @@ -123,7 +123,7 @@ class Api: with self.queue_lock: result = run_extras(extras_mode=2, image=None, image_folder=None, **reqDict) - return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) + return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index 6f096807..e461d397 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -130,8 +130,7 @@ class ExtrasBaseRequest(BaseModel): extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") class ExtraBaseResponse(BaseModel): - html_info_x: str - html_info: str + html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.") class ExtrasSingleImageRequest(ExtrasBaseRequest): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") @@ -139,9 +138,6 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest): class ExtrasSingleImageResponse(ExtraBaseResponse): image: str = Field(default=None, title="Image", description="The generated image in base64 format.") -class SerializableImage(BaseModel): - path: str = Field(title="Path", description="The image's path ()") - class ImageItem(BaseModel): data: str = Field(title="image data") name: str = Field(title="filename") -- cgit v1.2.3 From 124e44cf1eed1edc68954f63a2a9bc428aabbcec Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 24 Oct 2022 09:51:56 +0800 Subject: remove browser to extension --- .gitignore | 1 - javascript/images_history.js | 200 -------------------- javascript/inspiration.js | 48 ----- modules/images_history.py | 424 ------------------------------------------- modules/inspiration.py | 193 -------------------- modules/script_callbacks.py | 2 - modules/shared.py | 15 -- modules/ui.py | 20 +- 8 files changed, 4 insertions(+), 899 deletions(-) delete mode 100644 javascript/images_history.js delete mode 100644 javascript/inspiration.js delete mode 100644 modules/images_history.py delete mode 100644 modules/inspiration.py (limited to 'modules') diff --git a/.gitignore b/.gitignore index 8d01bc6a..70660c51 100644 --- a/.gitignore +++ b/.gitignore @@ -29,5 +29,4 @@ notification.mp3 /textual_inversion .vscode /extensions - /inspiration diff --git a/javascript/images_history.js b/javascript/images_history.js deleted file mode 100644 index c9aa76f8..00000000 --- a/javascript/images_history.js +++ /dev/null @@ -1,200 +0,0 @@ -var images_history_click_image = function(){ - if (!this.classList.contains("transform")){ - var gallery = images_history_get_parent_by_class(this, "images_history_cantainor"); - var buttons = gallery.querySelectorAll(".gallery-item"); - var i = 0; - var hidden_list = []; - buttons.forEach(function(e){ - if (e.style.display == "none"){ - hidden_list.push(i); - } - i += 1; - }) - if (hidden_list.length > 0){ - setTimeout(images_history_hide_buttons, 10, hidden_list, gallery); - } - } - images_history_set_image_info(this); -} - -function images_history_disabled_del(){ - gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){ - btn.setAttribute('disabled','disabled'); - }); -} - -function images_history_get_parent_by_class(item, class_name){ - var parent = item.parentElement; - while(!parent.classList.contains(class_name)){ - parent = parent.parentElement; - } - return parent; -} - -function images_history_get_parent_by_tagname(item, tagname){ - var parent = item.parentElement; - tagname = tagname.toUpperCase() - while(parent.tagName != tagname){ - parent = parent.parentElement; - } - return parent; -} - -function images_history_hide_buttons(hidden_list, gallery){ - var buttons = gallery.querySelectorAll(".gallery-item"); - var num = 0; - buttons.forEach(function(e){ - if (e.style.display == "none"){ - num += 1; - } - }); - if (num == hidden_list.length){ - setTimeout(images_history_hide_buttons, 10, hidden_list, gallery); - } - for( i in hidden_list){ - buttons[hidden_list[i]].style.display = "none"; - } -} - -function images_history_set_image_info(button){ - var buttons = images_history_get_parent_by_tagname(button, "DIV").querySelectorAll(".gallery-item"); - var index = -1; - var i = 0; - buttons.forEach(function(e){ - if(e == button){ - index = i; - } - if(e.style.display != "none"){ - i += 1; - } - }); - var gallery = images_history_get_parent_by_class(button, "images_history_cantainor"); - var set_btn = gallery.querySelector(".images_history_set_index"); - var curr_idx = set_btn.getAttribute("img_index", index); - if (curr_idx != index) { - set_btn.setAttribute("img_index", index); - images_history_disabled_del(); - } - set_btn.click(); - -} - -function images_history_get_current_img(tabname, img_index, files){ - return [ - tabname, - gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"), - files - ]; -} - -function images_history_delete(del_num, tabname, image_index){ - image_index = parseInt(image_index); - var tab = gradioApp().getElementById(tabname + '_images_history'); - var set_btn = tab.querySelector(".images_history_set_index"); - var buttons = []; - tab.querySelectorAll(".gallery-item").forEach(function(e){ - if (e.style.display != 'none'){ - buttons.push(e); - } - }); - var img_num = buttons.length / 2; - del_num = Math.min(img_num - image_index, del_num) - if (img_num <= del_num){ - setTimeout(function(tabname){ - gradioApp().getElementById(tabname + '_images_history_renew_page').click(); - }, 30, tabname); - } else { - var next_img - for (var i = 0; i < del_num; i++){ - buttons[image_index + i].style.display = 'none'; - buttons[image_index + i + img_num].style.display = 'none'; - next_img = image_index + i + 1 - } - var bnt; - if (next_img >= img_num){ - btn = buttons[image_index - 1]; - } else { - btn = buttons[next_img]; - } - setTimeout(function(btn){btn.click()}, 30, btn); - } - images_history_disabled_del(); - -} - -function images_history_turnpage(tabname){ - gradioApp().getElementById(tabname + '_images_history_del_button').setAttribute('disabled','disabled'); - var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item"); - buttons.forEach(function(elem) { - elem.style.display = 'block'; - }) -} - -function images_history_enable_del_buttons(){ - gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){ - btn.removeAttribute('disabled'); - }) -} - -function images_history_init(){ - var tabnames = gradioApp().getElementById("images_history_tabnames_list") - if (tabnames){ - images_history_tab_list = tabnames.querySelector("textarea").value.split(",") - for (var i in images_history_tab_list ){ - var tab = images_history_tab_list[i]; - gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor"); - gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index"); - gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button"); - gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery"); - gradioApp().getElementById(tab + "_images_history_start").setAttribute("style","padding:20px;font-size:25px"); - } - - //preload - if (gradioApp().getElementById("images_history_preload").querySelector("input").checked ){ - var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div"); - tabs_box.setAttribute("id", "images_history_tab"); - var tab_btns = tabs_box.querySelectorAll("button"); - for (var i in images_history_tab_list){ - var tabname = images_history_tab_list[i] - tab_btns[i].setAttribute("tabname", tabname); - tab_btns[i].addEventListener('click', function(){ - var tabs_box = gradioApp().getElementById("images_history_tab"); - if (!tabs_box.classList.contains(this.getAttribute("tabname"))) { - gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_start").click(); - tabs_box.classList.add(this.getAttribute("tabname")) - } - }); - } - tab_btns[0].click() - } - } else { - setTimeout(images_history_init, 500); - } -} - -var images_history_tab_list = ""; -setTimeout(images_history_init, 500); -document.addEventListener("DOMContentLoaded", function() { - var mutationObserver = new MutationObserver(function(m){ - if (images_history_tab_list != ""){ - for (var i in images_history_tab_list ){ - let tabname = images_history_tab_list[i] - var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item'); - buttons.forEach(function(bnt){ - bnt.addEventListener('click', images_history_click_image, true); - }); - - var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg"); - if (cls_btn){ - cls_btn.addEventListener('click', function(){ - gradioApp().getElementById(tabname + '_images_history_renew_page').click(); - }, false); - } - - } - } - }); - mutationObserver.observe(gradioApp(), { childList:true, subtree:true }); -}); - - diff --git a/javascript/inspiration.js b/javascript/inspiration.js deleted file mode 100644 index 39844544..00000000 --- a/javascript/inspiration.js +++ /dev/null @@ -1,48 +0,0 @@ -function public_image_index_in_gallery(item, gallery){ - var imgs = gallery.querySelectorAll("img.h-full") - var index; - var i = 0; - imgs.forEach(function(e){ - if (e == item) - index = i; - i += 1; - }); - var all_imgs = gallery.querySelectorAll("img") - if (all_imgs.length > imgs.length){ - var num = imgs.length / 2 - index = (index < num) ? index : (index - num) - } - return index; -} - -function inspiration_selected(name, name_list){ - var btn = gradioApp().getElementById("inspiration_select_button") - return [gradioApp().getElementById("inspiration_select_button").getAttribute("img-index")]; -} - -function inspiration_click_get_button(){ - gradioApp().getElementById("inspiration_get_button").click(); -} - -var inspiration_image_click = function(){ - var index = public_image_index_in_gallery(this, gradioApp().getElementById("inspiration_gallery")); - var btn = gradioApp().getElementById("inspiration_select_button"); - btn.setAttribute("img-index", index); - setTimeout(function(btn){btn.click();}, 10, btn); -} - -document.addEventListener("DOMContentLoaded", function() { - var mutationObserver = new MutationObserver(function(m){ - var gallery = gradioApp().getElementById("inspiration_gallery") - if (gallery) { - var node = gallery.querySelector(".absolute.backdrop-blur.h-full") - if (node) { - node.style.display = "None"; - } - gallery.querySelectorAll('img').forEach(function(e){ - e.onclick = inspiration_image_click - }); - } - }); - mutationObserver.observe( gradioApp(), { childList:true, subtree:true }); -}); diff --git a/modules/images_history.py b/modules/images_history.py deleted file mode 100644 index bc5cf11f..00000000 --- a/modules/images_history.py +++ /dev/null @@ -1,424 +0,0 @@ -import os -import shutil -import time -import hashlib -import gradio -system_bak_path = "webui_log_and_bak" -custom_tab_name = "custom fold" -faverate_tab_name = "favorites" -tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name] -def is_valid_date(date): - try: - time.strptime(date, "%Y%m%d") - return True - except: - return False - -def reduplicative_file_move(src, dst): - def same_name_file(basename, path): - name, ext = os.path.splitext(basename) - f_list = os.listdir(path) - max_num = 0 - for f in f_list: - if len(f) <= len(basename): - continue - f_ext = f[-len(ext):] if len(ext) > 0 else "" - if f[:len(name)] == name and f_ext == ext: - if f[len(name)] == "(" and f[-len(ext)-1] == ")": - number = f[len(name)+1:-len(ext)-1] - if number.isdigit(): - if int(number) > max_num: - max_num = int(number) - return f"{name}({max_num + 1}){ext}" - name = os.path.basename(src) - save_name = os.path.join(dst, name) - if not os.path.exists(save_name): - shutil.move(src, dst) - else: - name = same_name_file(name, dst) - shutil.move(src, os.path.join(dst, name)) - -def traverse_all_files(curr_path, image_list, all_type=False): - try: - f_list = os.listdir(curr_path) - except: - if all_type or (curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt" and curr_path[-4:] != ".csv"): - image_list.append(curr_path) - return image_list - for file in f_list: - file = os.path.join(curr_path, file) - if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"): - pass - elif os.path.isfile(file) and file[-10:].rfind(".") > 0: - image_list.append(file) - else: - image_list = traverse_all_files(file, image_list) - return image_list - -def auto_sorting(dir_name): - bak_path = os.path.join(dir_name, system_bak_path) - if not os.path.exists(bak_path): - os.mkdir(bak_path) - log_file = None - files_list = [] - f_list = os.listdir(dir_name) - for file in f_list: - if file == system_bak_path: - continue - file_path = os.path.join(dir_name, file) - if not is_valid_date(file): - if file[-10:].rfind(".") > 0: - files_list.append(file_path) - else: - files_list = traverse_all_files(file_path, files_list, all_type=True) - - for file in files_list: - date_str = time.strftime("%Y%m%d",time.localtime(os.path.getmtime(file))) - file_path = os.path.dirname(file) - hash_path = hashlib.md5(file_path.encode()).hexdigest() - path = os.path.join(dir_name, date_str, hash_path) - if not os.path.exists(path): - os.makedirs(path) - if log_file is None: - log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a") - log_file.write(f"{hash_path},{file_path}\n") - reduplicative_file_move(file, path) - - date_list = [] - f_list = os.listdir(dir_name) - for f in f_list: - if is_valid_date(f): - date_list.append(f) - elif f == system_bak_path: - continue - else: - try: - reduplicative_file_move(os.path.join(dir_name, f), bak_path) - except: - pass - - today = time.strftime("%Y%m%d",time.localtime(time.time())) - if today not in date_list: - date_list.append(today) - return sorted(date_list, reverse=True) - -def archive_images(dir_name, date_to): - filenames = [] - batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num) - if batch_size <= 0: - batch_size = opts.images_history_num_per_page * 6 - today = time.strftime("%Y%m%d",time.localtime(time.time())) - date_to = today if date_to is None or date_to == "" else date_to - date_to_bak = date_to - if False: #opts.images_history_reconstruct_directory: - date_list = auto_sorting(dir_name) - for date in date_list: - if date <= date_to: - path = os.path.join(dir_name, date) - if date == today and not os.path.exists(path): - continue - filenames = traverse_all_files(path, filenames) - if len(filenames) > batch_size: - break - filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file)) - else: - filenames = traverse_all_files(dir_name, filenames) - total_num = len(filenames) - tmparray = [(os.path.getmtime(file), file) for file in filenames ] - date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400 - filenames = [] - date_list = {date_to:None} - date = time.strftime("%Y%m%d",time.localtime(time.time())) - for t, f in tmparray: - date = time.strftime("%Y%m%d",time.localtime(t)) - date_list[date] = None - if t <= date_stamp: - filenames.append((t, f ,date)) - date_list = sorted(list(date_list.keys()), reverse=True) - sort_array = sorted(filenames, key=lambda x:-x[0]) - if len(sort_array) > batch_size: - date = sort_array[batch_size][2] - filenames = [x[1] for x in sort_array] - else: - date = date_to if len(sort_array) == 0 else sort_array[-1][2] - filenames = [x[1] for x in sort_array] - filenames = [x[1] for x in sort_array if x[2]>= date] - num = len(filenames) - last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000)) - date = date[:4] + "/" + date[4:6] + "/" + date[6:8] - date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8] - load_info = "
" - load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages" - load_info += "
" - _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames) - return ( - date_to, - load_info, - filenames, - 1, - image_list, - "", - "", - visible_num, - last_date_from, - gradio.update(visible=total_num > num) - ) - -def delete_image(delete_num, name, filenames, image_index, visible_num): - if name == "": - return filenames, delete_num - else: - delete_num = int(delete_num) - visible_num = int(visible_num) - image_index = int(image_index) - index = list(filenames).index(name) - i = 0 - new_file_list = [] - for name in filenames: - if i >= index and i < index + delete_num: - if os.path.exists(name): - if visible_num == image_index: - new_file_list.append(name) - i += 1 - continue - print(f"Delete file {name}") - os.remove(name) - visible_num -= 1 - txt_file = os.path.splitext(name)[0] + ".txt" - if os.path.exists(txt_file): - os.remove(txt_file) - else: - print(f"Not exists file {name}") - else: - new_file_list.append(name) - i += 1 - return new_file_list, 1, visible_num - -def save_image(file_name): - if file_name is not None and os.path.exists(file_name): - shutil.copy(file_name, opts.outdir_save) - -def get_recent_images(page_index, step, filenames): - page_index = int(page_index) - num_of_imgs_per_page = int(opts.images_history_num_per_page) - max_page_index = len(filenames) // num_of_imgs_per_page + 1 - page_index = max_page_index if page_index == -1 else page_index + step - page_index = 1 if page_index < 1 else page_index - page_index = max_page_index if page_index > max_page_index else page_index - idx_frm = (page_index - 1) * num_of_imgs_per_page - image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page] - length = len(filenames) - visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page - visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num - return page_index, image_list, "", "", visible_num - -def loac_batch_click(date_to): - if date_to is None: - return time.strftime("%Y%m%d",time.localtime(time.time())), [] - else: - return None, [] -def forward_click(last_date_from, date_to_recorder): - if len(date_to_recorder) == 0: - return None, [] - if last_date_from == date_to_recorder[-1]: - date_to_recorder = date_to_recorder[:-1] - if len(date_to_recorder) == 0: - return None, [] - return date_to_recorder[-1], date_to_recorder[:-1] - -def backward_click(last_date_from, date_to_recorder): - if last_date_from is None or last_date_from == "": - return time.strftime("%Y%m%d",time.localtime(time.time())), [] - if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]: - date_to_recorder.append(last_date_from) - return last_date_from, date_to_recorder - - -def first_page_click(page_index, filenames): - return get_recent_images(1, 0, filenames) - -def end_page_click(page_index, filenames): - return get_recent_images(-1, 0, filenames) - -def prev_page_click(page_index, filenames): - return get_recent_images(page_index, -1, filenames) - -def next_page_click(page_index, filenames): - return get_recent_images(page_index, 1, filenames) - -def page_index_change(page_index, filenames): - return get_recent_images(page_index, 0, filenames) - -def show_image_info(tabname_box, num, page_index, filenames): - file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] - tm = "
" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "
" - return file, tm, num, file - -def enable_page_buttons(): - return gradio.update(visible=True) - -def change_dir(img_dir, date_to): - warning = None - try: - if os.path.exists(img_dir): - try: - f = os.listdir(img_dir) - except: - warning = f"'{img_dir} is not a directory" - else: - warning = "The directory is not exist" - except: - warning = "The format of the directory is incorrect" - if warning is None: - today = time.strftime("%Y%m%d",time.localtime(time.time())) - return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True) - else: - return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False) - -def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): - custom_dir = False - if tabname == "txt2img": - dir_name = opts.outdir_txt2img_samples - elif tabname == "img2img": - dir_name = opts.outdir_img2img_samples - elif tabname == "extras": - dir_name = opts.outdir_extras_samples - elif tabname == faverate_tab_name: - dir_name = opts.outdir_save - else: - custom_dir = True - dir_name = None - - if not custom_dir: - d = dir_name.split("/") - dir_name = d[0] - for p in d[1:]: - dir_name = os.path.join(dir_name, p) - if not os.path.exists(dir_name): - os.makedirs(dir_name) - - with gr.Column() as page_panel: - with gr.Row(): - with gr.Column(scale=1, visible=not custom_dir) as load_batch_box: - load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True) - with gr.Column(scale=4): - with gr.Row(): - img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir) - with gr.Row(): - with gr.Column(visible=False, scale=1) as batch_panel: - with gr.Row(): - forward = gr.Button('Prev batch') - backward = gr.Button('Next batch') - with gr.Column(scale=3): - load_info = gr.HTML(visible=not custom_dir) - with gr.Row(visible=False) as warning: - warning_box = gr.Textbox("Message", interactive=False) - - with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel: - with gr.Column(scale=2): - with gr.Row(visible=True) as turn_page_buttons: - #date_to = gr.Dropdown(label="Date to") - first_page = gr.Button('First Page') - prev_page = gr.Button('Prev Page') - page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next Page') - end_page = gr.Button('End Page') - - history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num) - with gr.Row(): - delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") - delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - - with gr.Column(): - with gr.Row(): - with gr.Column(): - img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6) - gr.HTML("
") - img_file_name = gr.Textbox(value="", label="File Name", interactive=False) - img_file_time= gr.HTML() - with gr.Row(): - if tabname != faverate_tab_name: - save_btn = gr.Button('Collect') - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') - - - # hiden items - with gr.Row(visible=False): - renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") - batch_date_to = gr.Textbox(label="Date to") - visible_img_num = gr.Number() - date_to_recorder = gr.State([]) - last_date_from = gr.Textbox() - tabname_box = gr.Textbox(tabname) - image_index = gr.Textbox(value=-1) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") - filenames = gr.State() - all_images_list = gr.State() - hidden = gr.Image(type="pil") - info1 = gr.Textbox() - info2 = gr.Textbox() - - img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info]) - - #change batch - change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel] - - batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output) - batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) - batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - - load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder]) - forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) - backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) - - - #delete - delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num]) - delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None) - if tabname != faverate_tab_name: - save_btn.click(save_image, inputs=[img_file_name], outputs=None) - - #turn page - gallery_inputs = [page_index, filenames] - gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num] - first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) - renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) - - first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - - # other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden]) - img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) - hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) - switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') - switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') - - - -def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict): - global opts; - opts = sys_opts - loads_files_num = int(opts.images_history_num_per_page) - num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num) - if cmp_ops.browse_all_images: - tabs_list.append(custom_tab_name) - with gr.Blocks(analytics_enabled=False) as images_history: - with gr.Tabs() as tabs: - for tab in tabs_list: - with gr.Tab(tab): - with gr.Blocks(analytics_enabled=False) : - show_images_history(gr, opts, tab, run_pnginfo, switch_dict) - gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False) - gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False) - - return images_history diff --git a/modules/inspiration.py b/modules/inspiration.py deleted file mode 100644 index 29cf8297..00000000 --- a/modules/inspiration.py +++ /dev/null @@ -1,193 +0,0 @@ -import os -import random -import gradio -from modules.shared import opts -inspiration_system_path = os.path.join(opts.inspiration_dir, "system") -def read_name_list(file, types=None, keyword=None): - if not os.path.exists(file): - return [] - ret = [] - f = open(file, "r") - line = f.readline() - while len(line) > 0: - line = line.rstrip("\n") - if types is not None: - dirname = os.path.split(line) - if dirname[0] in types and keyword in dirname[1].lower(): - ret.append(line) - else: - ret.append(line) - line = f.readline() - return ret - -def save_name_list(file, name): - name_list = read_name_list(file) - if name not in name_list: - with open(file, "a") as f: - f.write(name + "\n") - -def get_types_list(): - files = os.listdir(opts.inspiration_dir) - types = [] - for x in files: - path = os.path.join(opts.inspiration_dir, x) - if x[0] == ".": - continue - if not os.path.isdir(path): - continue - if path == inspiration_system_path: - continue - types.append(x) - return types - -def get_inspiration_images(source, types, keyword): - keyword = keyword.strip(" ").lower() - get_num = int(opts.inspiration_rows_num * opts.inspiration_cols_num) - if source == "Favorites": - names = read_name_list(os.path.join(inspiration_system_path, "faverites.txt"), types, keyword) - names = random.sample(names, get_num) if len(names) > get_num else names - elif source == "Abandoned": - names = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) - names = random.sample(names, get_num) if len(names) > get_num else names - elif source == "Exclude abandoned": - abandoned = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) - all_names = [] - for tp in types: - name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) - all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()] - - if len(all_names) > get_num: - names = [] - while len(names) < get_num: - name = random.choice(all_names) - if name not in abandoned: - names.append(name) - else: - names = all_names - else: - all_names = [] - for tp in types: - name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) - all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()] - names = random.sample(all_names, get_num) if len(all_names) > get_num else all_names - image_list = [] - for a in names: - image_path = os.path.join(opts.inspiration_dir, a) - images = os.listdir(image_path) - if len(images) > 0: - image_list.append((os.path.join(image_path, random.choice(images)), a)) - else: - print(image_path) - return image_list, names - -def select_click(index, name_list): - name = name_list[int(index)] - path = os.path.join(opts.inspiration_dir, name) - images = os.listdir(path) - return name, [os.path.join(path, x) for x in images], "" - -def give_up_click(name): - file = os.path.join(inspiration_system_path, "abandoned.txt") - save_name_list(file, name) - return "Added to abandoned list" - -def collect_click(name): - file = os.path.join(inspiration_system_path, "faverites.txt") - save_name_list(file, name) - return "Added to faverite list" - -def moveout_click(name, source): - if source == "Abandoned": - file = os.path.join(inspiration_system_path, "abandoned.txt") - elif source == "Favorites": - file = os.path.join(inspiration_system_path, "faverites.txt") - else: - return None - name_list = read_name_list(file) - os.remove(file) - with open(file, "a") as f: - for a in name_list: - if a != name: - f.write(a + "\n") - return f"Moved out {name} from {source} list" - -def source_change(source): - if source in ["Abandoned", "Favorites"]: - return gradio.update(visible=True), [] - else: - return gradio.update(visible=False), [] -def add_to_prompt(name, prompt): - name = os.path.basename(name) - return prompt + "," + name - -def clear_keyword(): - return "" - -def ui(gr, opts, txt2img_prompt, img2img_prompt): - with gr.Blocks(analytics_enabled=False) as inspiration: - flag = os.path.exists(opts.inspiration_dir) - if flag: - types = get_types_list() - flag = len(types) > 0 - else: - os.makedirs(opts.inspiration_dir) - if not flag: - gr.HTML(""" -

To activate inspiration function, you need get "inspiration" images first.


- You can create these images by run "Create inspiration images" script in txt2img page,
you can get the artists or art styles list from here
- https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
- download these files, and select these files in the "Create inspiration images" script UI
- There about 6000 artists and art styles in these files.
This takes server hours depending on your GPU type and how many pictures you generate for each artist/style -
I suggest at least four images for each


-

You can also download generated pictures from here:


- https://huggingface.co/datasets/yfszzx/inspiration
- unzip the file to the project directory of webui
- and restart webui, and enjoy the joy of creation!
- """) - return inspiration - if not os.path.exists(inspiration_system_path): - os.mkdir(inspiration_system_path) - with gr.Row(): - with gr.Column(scale=2): - inspiration_gallery = gr.Gallery(show_label=False, elem_id="inspiration_gallery").style(grid=opts.inspiration_cols_num, height='auto') - with gr.Column(scale=1): - types = gr.CheckboxGroup(choices=types, value=types) - with gr.Row(): - source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") - keyword = gr.Textbox("", label="Key word") - get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button") - name = gr.Textbox(show_label=False, interactive=False) - with gr.Row(): - send_to_txt2img = gr.Button('to txt2img') - send_to_img2img = gr.Button('to img2img') - collect = gr.Button('Collect') - give_up = gr.Button("Don't show again") - moveout = gr.Button("Move out", visible=False) - warning = gr.HTML() - style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') - - - - with gr.Row(visible=False): - select_button = gr.Button('set button', elem_id="inspiration_select_button") - name_list = gr.State() - - get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list]) - keyword.submit(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) - source.change(source_change, inputs=[source], outputs=[moveout, style_gallery]) - source.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword]) - types.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword]) - - select_button.click(select_click, _js="inspiration_selected", inputs=[name, name_list], outputs=[name, style_gallery, warning]) - give_up.click(give_up_click, inputs=[name], outputs=[warning]) - collect.click(collect_click, inputs=[name], outputs=[warning]) - moveout.click(moveout_click, inputs=[name, source], outputs=[warning]) - moveout.click(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) - - send_to_txt2img.click(add_to_prompt, inputs=[name, txt2img_prompt], outputs=[txt2img_prompt]) - send_to_img2img.click(add_to_prompt, inputs=[name, img2img_prompt], outputs=[img2img_prompt]) - send_to_txt2img.click(collect_click, inputs=[name], outputs=[warning]) - send_to_img2img.click(collect_click, inputs=[name], outputs=[warning]) - send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None) - send_to_img2img.click(None, _js="switch_to_img2img_img2img", inputs=None, outputs=None) - return inspiration diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 5bcccd67..66666a56 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,4 +1,3 @@ - callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] @@ -16,7 +15,6 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] - for callback in callbacks_ui_tabs: res += callback() or [] diff --git a/modules/shared.py b/modules/shared.py index 0aaaadac..5dfd7927 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -321,21 +321,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) -options_templates.update(options_section(('inspiration', "Inspiration"), { - "inspiration_dir": OptionInfo("inspiration", "Directory of inspiration", component_args=hide_dirs), - "inspiration_max_samples": OptionInfo(4, "Maximum number of samples, used to determine which folders to skip when continue running the create script", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), - "inspiration_rows_num": OptionInfo(4, "Rows of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), - "inspiration_cols_num": OptionInfo(8, "Columns of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), -})) - -options_templates.update(options_section(('images-history', "Images Browser"), { - #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), - "images_history_preload": OptionInfo(False, "Preload images at startup"), - "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"), - "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "), - "images_history_grid_num": OptionInfo(6, "Number of grids in each row"), - -})) class Options: data = None diff --git a/modules/ui.py b/modules/ui.py index a73175f5..fa42712e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -49,14 +49,12 @@ from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.images_history as images_history -import modules.inspiration as inspiration - - # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') +txt2img_paste_fields = [] +img2img_paste_fields = [] if not cmd_opts.share and not cmd_opts.listen: @@ -1193,16 +1191,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[image], outputs=[html, generation_info, html2], ) - #images history - images_history_switch_dict = { - "fn": modules.generation_parameters_copypaste.connect_paste, - "t2i": txt2img_paste_fields, - "i2i": img2img_paste_fields - } - - browser_interface = images_history.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) - inspiration_interface = inspiration.ui(gr, opts, txt2img_prompt, img2img_prompt) - + with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): @@ -1651,8 +1640,6 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (inspiration_interface, "Inspiration", "inspiration"), - (browser_interface , "Image Browser", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), ] @@ -1896,6 +1883,7 @@ def load_javascript(raw_response): javascript = f'' scripts_list = modules.scripts.list_scripts("javascript", ".js") + scripts_list += modules.scripts.list_scripts("scripts", ".js") for basedir, filename, path in scripts_list: with open(path, "r", encoding="utf8") as jsfile: javascript += f"\n" -- cgit v1.2.3 From cef1b89aa2e6c7647db7e93a4cd4ec020da3f2da Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 24 Oct 2022 10:10:33 +0800 Subject: remove browser to extension --- modules/script_callbacks.py | 2 ++ modules/shared.py | 1 - modules/ui.py | 2 +- scripts/create_inspiration_images.py | 57 ------------------------------------ 4 files changed, 3 insertions(+), 59 deletions(-) delete mode 100644 scripts/create_inspiration_images.py (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 66666a56..f46d3d9a 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,3 +1,4 @@ + callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] @@ -15,6 +16,7 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] + for callback in callbacks_ui_tabs: res += callback() or [] diff --git a/modules/shared.py b/modules/shared.py index 5dfd7927..6541e679 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -82,7 +82,6 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) -parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False) cmd_opts = parser.parse_args() restricted_opts = [ diff --git a/modules/ui.py b/modules/ui.py index fa42712e..a32f7259 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1104,7 +1104,7 @@ def create_ui(wrap_gradio_gpu_call): upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers] , value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") diff --git a/scripts/create_inspiration_images.py b/scripts/create_inspiration_images.py deleted file mode 100644 index 2fd30578..00000000 --- a/scripts/create_inspiration_images.py +++ /dev/null @@ -1,57 +0,0 @@ -import csv, os, shutil -import modules.scripts as scripts -from modules import processing, shared, sd_samplers, images -from modules.processing import Processed -from modules.shared import opts -import gradio -class Script(scripts.Script): - def title(self): - return "Create inspiration images" - - def show(self, is_img2img): - return True - - def ui(self, is_img2img): - file = gradio.Files(label="Artist or styles name list. '.txt' files with one name per line",) - with gradio.Row(): - prefix = gradio.Textbox("a painting in", label="Prompt words before artist or style name", file_count="multiple") - suffix= gradio.Textbox("style", label="Prompt words after artist or style name") - negative_prompt = gradio.Textbox("picture frame, portrait photo", label="Negative Prompt") - with gradio.Row(): - batch_size = gradio.Number(1, label="Batch size") - batch_count = gradio.Number(2, label="Batch count") - return [batch_size, batch_count, prefix, suffix, negative_prompt, file] - - def run(self, p, batch_size, batch_count, prefix, suffix, negative_prompt, files): - p.batch_size = int(batch_size) - p.n_iterint = int(batch_count) - p.negative_prompt = negative_prompt - p.do_not_save_samples = True - p.do_not_save_grid = True - for file in files: - tp = file.orig_name.split(".")[0] - print(tp) - path = os.path.join(opts.inspiration_dir, tp) - if not os.path.exists(path): - os.makedirs(path) - f = open(file.name, "r") - line = f.readline() - while len(line) > 0: - name = line.rstrip("\n").split(",")[0] - line = f.readline() - artist_path = os.path.join(path, name) - if not os.path.exists(artist_path): - os.mkdir(artist_path) - if len(os.listdir(artist_path)) >= opts.inspiration_max_samples: - continue - p.prompt = f"{prefix} {name} {suffix}" - print(p.prompt) - processed = processing.process_images(p) - for img in processed.images: - i = 0 - filename = os.path.join(artist_path, format(0, "03d") + ".jpg") - while os.path.exists(filename): - i += 1 - filename = os.path.join(artist_path, format(i, "03d") + ".jpg") - img.save(filename, quality=80) - return processed -- cgit v1.2.3 From a889c93f23f1e80d0dac4e5ddbc3a26207e8cdf1 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 24 Oct 2022 11:13:16 +0800 Subject: paste_fields add to public --- modules/ui.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index a32f7259..a73b9ff0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -784,6 +784,7 @@ def create_ui(wrap_gradio_gpu_call): ] ) + global txt2img_paste_fields txt2img_paste_fields = [ (txt2img_prompt, "Prompt"), (txt2img_negative_prompt, "Negative prompt"), @@ -1054,6 +1055,7 @@ def create_ui(wrap_gradio_gpu_call): outputs=[prompt, negative_prompt, style1, style2], ) + global img2img_paste_fields img2img_paste_fields = [ (img2img_prompt, "Prompt"), (img2img_negative_prompt, "Negative prompt"), -- cgit v1.2.3 From 974196932583b96b6b76632052fc0d7e70820bf3 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sun, 23 Oct 2022 22:38:42 +0300 Subject: Save properly processed image before color correction --- modules/processing.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index ff83023c..15b639e1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -46,6 +46,20 @@ def apply_color_correction(correction, image): return image +def apply_overlay(overlay_exists, overlay, paste_loc, image): + if overlay_exists: + if paste_loc is not None: + x, y, w, h = paste_loc + base_image = Image.new('RGBA', (overlay.width, overlay.height)) + image = images.resize_image(1, image, w, h) + base_image.paste(image, (x, y)) + image = base_image + + image = image.convert('RGBA') + image.alpha_composite(overlay) + image = image.convert('RGB') + + return image def get_correct_sampler(p): if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img): @@ -446,25 +460,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() image = Image.fromarray(x_sample) - + if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: - images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction") + image_without_cc = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image) + images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) - if p.overlay_images is not None and i < len(p.overlay_images): - overlay = p.overlay_images[i] - - if p.paste_to is not None: - x, y, w, h = p.paste_to - base_image = Image.new('RGBA', (overlay.width, overlay.height)) - image = images.resize_image(1, image, w, h) - base_image.paste(image, (x, y)) - image = base_image - - image = image.convert('RGBA') - image.alpha_composite(overlay) - image = image.convert('RGB') + image = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image) if opts.samples_save and not p.do_not_save_samples: images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) -- cgit v1.2.3 From f2cc3f32d5bc8538e95edec54d7dc1b9efdf769a Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sun, 23 Oct 2022 22:44:46 +0300 Subject: fix whitespaces --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 15b639e1..2a332514 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -460,7 +460,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() image = Image.fromarray(x_sample) - + if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: image_without_cc = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image) -- cgit v1.2.3 From b297cc3324979ec78d69b2d11dd18030dfad7bcc Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 20:06:42 +0900 Subject: Hypernetworks - fix KeyError in statistics caching Statistics logging has changed to {filename : list[losses]}, so it has to use loss_info[key].pop() --- modules/hypernetworks/hypernetwork.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 98a7b62e..33827210 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -274,8 +274,8 @@ def log_statistics(loss_info:dict, key, value): loss_info[key] = [value] else: loss_info[key].append(value) - if len(loss_info) > 1024: - loss_info.pop(0) + if len(loss_info[key]) > 1024: + loss_info[key].pop(0) def statistics(data): -- cgit v1.2.3 From 40b56c9289bf9458ae5ef3c1990ccea851c6c3e2 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:07:07 +0900 Subject: cleanup some code --- modules/hypernetworks/hypernetwork.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 33827210..4072bf54 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum +from collections import defaultdict, deque from statistics import stdev, mean class HypernetworkModule(torch.nn.Module): @@ -269,15 +270,6 @@ def stack_conds(conds): return torch.stack(conds) -def log_statistics(loss_info:dict, key, value): - if key not in loss_info: - loss_info[key] = [value] - else: - loss_info[key].append(value) - if len(loss_info[key]) > 1024: - loss_info[key].pop(0) - - def statistics(data): total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})" recent_data = data[-32:] @@ -341,7 +333,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log weight.requires_grad = True size = len(ds.indexes) - loss_dict = {} + loss_dict = defaultdict(lambda : deque(maxlen = 1024)) losses = torch.zeros((size,)) previous_mean_loss = 0 print("Mean loss of {} elements".format(size)) @@ -383,7 +375,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log losses[hypernetwork.step % losses.shape[0]] = loss.item() for entry in entries: - log_statistics(loss_dict, entry.filename, loss.item()) + loss_dict[entry.filename].append(loss.item()) optimizer.zero_grad() weights[0].grad = None -- cgit v1.2.3 From 348f89c8d40397c1875cff4a7331018785f9c3b8 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:29:53 +0900 Subject: statistics for pbar --- modules/hypernetworks/hypernetwork.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 4072bf54..48b56029 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -335,6 +335,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log size = len(ds.indexes) loss_dict = defaultdict(lambda : deque(maxlen = 1024)) losses = torch.zeros((size,)) + previous_mean_losses = [0] previous_mean_loss = 0 print("Mean loss of {} elements".format(size)) @@ -356,7 +357,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log for i, entries in pbar: hypernetwork.step = i + ititial_step if len(loss_dict) > 0: - previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict) + previous_mean_losses = [i[-1] for i in loss_dict.values()] + previous_mean_loss = mean(previous_mean_losses) scheduler.apply(optimizer, hypernetwork.step) if scheduler.finished: @@ -391,7 +393,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): raise RuntimeError("Loss diverged.") - pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}") + + if len(previous_mean_losses) > 1: + std = stdev(previous_mean_losses) + else: + std = 0 + dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" + pbar.set_description(dataset_loss_info) if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. -- cgit v1.2.3 From 0d2e1dac407a0e2f5b148d314715f0457b2525b7 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:41:39 +0900 Subject: convert deque -> list I don't feel this being efficient --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 48b56029..fb510fa7 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -282,7 +282,7 @@ def report_statistics(loss_info:dict): for key in keys: try: print("Loss statistics for file " + key) - info, recent = statistics(loss_info[key]) + info, recent = statistics(list(loss_info[key])) print(info) print(recent) except Exception as e: -- cgit v1.2.3 From e9a410b5357612f63528015c5533c2185dcff92e Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 21:47:39 +0900 Subject: check length for variance --- modules/hypernetworks/hypernetwork.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index fb510fa7..d647ea55 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -271,9 +271,17 @@ def stack_conds(conds): def statistics(data): - total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})" + if len(data) < 2: + std = 0 + else: + std = stdev(data) + total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})" recent_data = data[-32:] - recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})" + if len(recent_data) < 2: + std = 0 + else: + std = stdev(recent_data) + recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})" return total_information, recent_information -- cgit v1.2.3 From 6cbb04f7a5e675cf1f6dfc247aa9c9e8df7dc5ce Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 24 Oct 2022 09:15:26 +0300 Subject: fix #3517 breaking txt2img --- modules/processing.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 2a332514..c61bbfbd 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -46,18 +46,23 @@ def apply_color_correction(correction, image): return image -def apply_overlay(overlay_exists, overlay, paste_loc, image): - if overlay_exists: - if paste_loc is not None: - x, y, w, h = paste_loc - base_image = Image.new('RGBA', (overlay.width, overlay.height)) - image = images.resize_image(1, image, w, h) - base_image.paste(image, (x, y)) - image = base_image - - image = image.convert('RGBA') - image.alpha_composite(overlay) - image = image.convert('RGB') + +def apply_overlay(image, paste_loc, index, overlays): + if overlays is None or index >= len(overlays): + return image + + overlay = overlays[index] + + if paste_loc is not None: + x, y, w, h = paste_loc + base_image = Image.new('RGBA', (overlay.width, overlay.height)) + image = images.resize_image(1, image, w, h) + base_image.paste(image, (x, y)) + image = base_image + + image = image.convert('RGBA') + image.alpha_composite(overlay) + image = image.convert('RGB') return image @@ -463,11 +468,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: - image_without_cc = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image) + image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) - image = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image) + image = apply_overlay(image, p.paste_to, i, p.overlay_images) if opts.samples_save and not p.do_not_save_samples: images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) -- cgit v1.2.3 From 734986dde3231416813f827242c111da212b2ccb Mon Sep 17 00:00:00 2001 From: Trung Ngo Date: Mon, 24 Oct 2022 01:17:09 -0500 Subject: add callback after image is saved --- modules/images.py | 3 ++- modules/script_callbacks.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index b9589563..01c60f89 100644 --- a/modules/images.py +++ b/modules/images.py @@ -12,7 +12,7 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin from fonts.ttf import Roboto import string -from modules import sd_samplers, shared +from modules import sd_samplers, shared, script_callbacks from modules.shared import opts, cmd_opts LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) @@ -467,6 +467,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i else: txt_fullfn = None + script_callbacks.image_saved_callback(image, p, fullfn, txt_fullfn) return fullfn, txt_fullfn diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 5bcccd67..5836e4b9 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,11 +2,12 @@ callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] - +callbacks_image_saved = [] def clear_callbacks(): callbacks_model_loaded.clear() callbacks_ui_tabs.clear() + callbacks_image_saved.clear() def model_loaded_callback(sd_model): @@ -28,6 +29,10 @@ def ui_settings_callback(): callback() +def image_saved_callback(image, p, fullfn, txt_fullfn): + for callback in callbacks_image_saved: + callback(image, p, fullfn, txt_fullfn) + def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument""" @@ -51,3 +56,8 @@ def on_ui_settings(callback): """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ callbacks_ui_settings.append(callback) + + +def on_save_imaged(callback): + """register a function to call after modules.images.save_image is called returning same values, original image and p """ + callbacks_image_saved.append(callback) -- cgit v1.2.3 From 876a96f0f9843382ebc8984db3de5d8af0e9ce4c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 24 Oct 2022 09:39:46 +0300 Subject: remove erroneous dir in the extension directory remove loading .js files from scripts dir (they go into javascript) load scripts after models, for scripts that depend on loaded models --- extensions/stable-diffusion-webui-inspiration | 1 - modules/ui.py | 2 +- webui.py | 11 ++++++----- 3 files changed, 7 insertions(+), 7 deletions(-) delete mode 160000 extensions/stable-diffusion-webui-inspiration (limited to 'modules') diff --git a/extensions/stable-diffusion-webui-inspiration b/extensions/stable-diffusion-webui-inspiration deleted file mode 160000 index a0b96664..00000000 --- a/extensions/stable-diffusion-webui-inspiration +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a0b96664d2524b87916ae463fbb65411b13a569b diff --git a/modules/ui.py b/modules/ui.py index a73b9ff0..03528968 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1885,7 +1885,7 @@ def load_javascript(raw_response): javascript = f'' scripts_list = modules.scripts.list_scripts("javascript", ".js") - scripts_list += modules.scripts.list_scripts("scripts", ".js") + for basedir, filename, path in scripts_list: with open(path, "r", encoding="utf8") as jsfile: javascript += f"\n" diff --git a/webui.py b/webui.py index a0f3757f..ade7334b 100644 --- a/webui.py +++ b/webui.py @@ -9,7 +9,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import devices, sd_samplers +from modules import devices, sd_samplers, upscaler import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -73,12 +73,11 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def initialize(): - modules.scripts.load_scripts() if cmd_opts.ui_debug_mode: - class enmpty(): - name = None - shared.sd_upscalers = [enmpty()] + shared.sd_upscalers = upscaler.UpscalerLanczos().scalers + modules.scripts.load_scripts() return + modelloader.cleanup_models() modules.sd_models.setup_model() codeformer.setup_model(cmd_opts.codeformer_models_path) @@ -86,6 +85,8 @@ def initialize(): shared.face_restorers.append(modules.face_restoration.FaceRestoration()) modelloader.load_upscalers() + modules.scripts.load_scripts() + modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) -- cgit v1.2.3 From 3be6b29d81408d2adb741bff5b11c80214aa621e Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 24 Oct 2022 15:14:34 +0900 Subject: indent=4 config.json --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 6541e679..d6ddfe59 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -348,7 +348,7 @@ class Options: def save(self, filename): with open(filename, "w", encoding="utf8") as file: - json.dump(self.data, file) + json.dump(self.data, file, indent=4) def same_type(self, x, y): if x is None or y is None: -- cgit v1.2.3 From c5d90628a4058bf49c2fdabf620a24db73407f31 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 22 Oct 2022 17:16:55 +0900 Subject: move "file_decoration" initialize section into "if forced_filename is None:" no need to initialize it if it's not going to be used --- modules/images.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index b9589563..50a59cff 100644 --- a/modules/images.py +++ b/modules/images.py @@ -386,18 +386,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i txt_fullfn (`str` or None): If a text file is saved for this image, this will be its full path. Otherwise None. ''' - if short_filename or prompt is None or seed is None: - file_decoration = "" - elif opts.save_to_dirs: - file_decoration = opts.samples_filename_pattern or "[seed]" - else: - file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]" - - if file_decoration != "": - file_decoration = "-" + file_decoration.lower() - - file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix - if extension == 'png' and opts.enable_pnginfo and info is not None: pnginfo = PngImagePlugin.PngInfo() @@ -419,6 +407,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i os.makedirs(path, exist_ok=True) if forced_filename is None: + if short_filename or prompt is None or seed is None: + file_decoration = "" + elif opts.save_to_dirs: + file_decoration = opts.samples_filename_pattern or "[seed]" + else: + file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]" + + if file_decoration != "": + file_decoration = "-" + file_decoration.lower() + + file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix + basecount = get_next_sequence_number(path, basename) fullfn = "a.png" fullfn_without_extension = "a" -- cgit v1.2.3 From 7d4a4db9ea7543c079f4a4a702c2945f4b66cd11 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 22 Oct 2022 17:48:59 +0900 Subject: modify unnecessary sting assignment as it's going to get overwritten --- modules/images.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 50a59cff..cc5066b1 100644 --- a/modules/images.py +++ b/modules/images.py @@ -420,8 +420,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix basecount = get_next_sequence_number(path, basename) - fullfn = "a.png" - fullfn_without_extension = "a" + fullfn = None + fullfn_without_extension = None for i in range(500): fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}" fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}") -- cgit v1.2.3 From 37dd6deafb831a809eaf7ae8d232937a8c7998e7 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 22 Oct 2022 21:11:15 +0900 Subject: filename pattern [datetime], extended customizable Format and Time Zone format: [datetime] [datetime] [datetime
+
+
+{versions} +
diff --git a/launch.py b/launch.py index af0d418b..49b91b1f 100644 --- a/launch.py +++ b/launch.py @@ -13,6 +13,21 @@ dir_extensions = "extensions" python = sys.executable git = os.environ.get('GIT', "git") index_url = os.environ.get('INDEX_URL', "") +stored_commit_hash = None + + +def commit_hash(): + global stored_commit_hash + + if stored_commit_hash is not None: + return stored_commit_hash + + try: + stored_commit_hash = run(f"{git} rev-parse HEAD").strip() + except Exception: + stored_commit_hash = "" + + return stored_commit_hash def extract_arg(args, name): @@ -194,10 +209,7 @@ def prepare_environment(): xformers = '--xformers' in sys.argv ngrok = '--ngrok' in sys.argv - try: - commit = run(f"{git} rev-parse HEAD").strip() - except Exception: - commit = "" + commit = commit_hash() print(f"Python {sys.version}") print(f"Commit hash: {commit}") diff --git a/modules/ui.py b/modules/ui.py index bb64fe20..81d96c5b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1696,7 +1696,9 @@ def create_ui(): if os.path.exists("html/footer.html"): with open("html/footer.html", encoding="utf8") as file: - gr.HTML(file.read(), elem_id="footer") + footer = file.read() + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) settings_submit.click( @@ -1857,3 +1859,30 @@ def reload_javascript(): if not hasattr(shared, 'GradioTemplateResponseOriginal'): shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse + + +def versions_html(): + import torch + import launch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = launch.commit_hash() + short_commit = commit[0:8] + + if shared.xformers_available: + import xformers + xformers_version = xformers.__version__ + else: + xformers_version = "N/A" + + return f""" +python: {python_version} + •  +torch: {torch.__version__} + •  +xformers: {xformers_version} + •  +gradio: {gr.__version__} + •  +commit: {short_commit} +""" diff --git a/style.css b/style.css index 09ee540b..ee74d79e 100644 --- a/style.css +++ b/style.css @@ -628,6 +628,11 @@ footer { display: inline-block; } +#footer .versions{ + font-size: 85%; + opacity: 0.85; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From f8d0cf6a6ec4911559cfecb9a9d1d46b547b38e8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 5 Jan 2023 12:08:11 +0300 Subject: rework #6329 to remove duplicate code and add prevent tab names for showing in ids for scripts that only exist on one tab --- modules/scripts.py | 10 ++++++++++ scripts/custom_code.py | 6 ------ scripts/img2imgalt.py | 6 ------ scripts/loopback.py | 6 ------ scripts/outpainting_mk_2.py | 6 ------ scripts/poor_mans_outpainting.py | 6 ------ scripts/prompt_matrix.py | 6 ------ scripts/prompts_from_file.py | 6 ------ scripts/sd_upscale.py | 6 ------ scripts/xy_grid.py | 5 ----- 10 files changed, 10 insertions(+), 53 deletions(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 722f8685..0c44f191 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,4 +1,5 @@ import os +import re import sys import traceback from collections import namedtuple @@ -128,6 +129,15 @@ class Script: """unused""" return "" + def elem_id(self, item_id): + """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id""" + + need_tabname = self.show(True) == self.show(False) + tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else "" + title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower())) + + return f'script_{tabname}{title}_{item_id}' + current_basedir = paths.script_path diff --git a/scripts/custom_code.py b/scripts/custom_code.py index 9ce1f650..d29113e6 100644 --- a/scripts/custom_code.py +++ b/scripts/custom_code.py @@ -3,18 +3,12 @@ import gradio as gr from modules.processing import Processed from modules.shared import opts, cmd_opts, state -import re class Script(scripts.Script): def title(self): return "Custom code" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return cmd_opts.allow_code diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 7555e874..cbdfc6b3 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -16,7 +16,6 @@ import k_diffusion as K from PIL import Image from torch import autocast from einops import rearrange, repeat -import re def find_noise_for_image(p, cond, uncond, cfg_scale, steps): @@ -123,11 +122,6 @@ class Script(scripts.Script): def title(self): return "img2img alternative test" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/loopback.py b/scripts/loopback.py index 4df7b73f..1dab9476 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -8,18 +8,12 @@ from modules import processing, shared, sd_samplers, images from modules.processing import Processed from modules.sd_samplers import samplers from modules.shared import opts, cmd_opts, state -import re class Script(scripts.Script): def title(self): return "Loopback" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index b4a0dc73..0906da6a 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -10,7 +10,6 @@ from PIL import Image, ImageDraw from modules import images, processing, devices from modules.processing import Processed, process_images from modules.shared import opts, cmd_opts, state -import re # this function is taken from https://github.com/parlance-zz/g-diffuser-bot @@ -123,11 +122,6 @@ class Script(scripts.Script): def title(self): return "Outpainting mk2" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 1c7dc467..d8feda00 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -7,18 +7,12 @@ from PIL import Image, ImageDraw from modules import images, processing, devices from modules.processing import Processed, process_images from modules.shared import opts, cmd_opts, state -import re class Script(scripts.Script): def title(self): return "Poor man's outpainting" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index 278d2e68..dd95e588 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -10,7 +10,6 @@ from modules import images from modules.processing import process_images, Processed from modules.shared import opts, cmd_opts, state import modules.sd_samplers -import re def draw_xy_grid(xs, ys, x_label, y_label, cell): @@ -45,11 +44,6 @@ class Script(scripts.Script): def title(self): return "Prompt matrix" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def ui(self, is_img2img): put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start")) different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 5c84c3e9..2751f98a 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -13,7 +13,6 @@ from modules import sd_samplers from modules.processing import Processed, process_images from PIL import Image from modules.shared import opts, cmd_opts, state -import re def process_string_tag(tag): @@ -112,11 +111,6 @@ class Script(scripts.Script): def title(self): return "Prompts from file or textbox" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def ui(self, is_img2img): checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 247e755b..9b8ffd85 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -7,18 +7,12 @@ from PIL import Image from modules import processing, shared, sd_samplers, images, devices from modules.processing import Processed from modules.shared import opts, cmd_opts, state -import re class Script(scripts.Script): def title(self): return "SD upscale" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index b277a439..f04d9b7e 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -290,11 +290,6 @@ class Script(scripts.Script): def title(self): return "X/Y plot" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def ui(self, is_img2img): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] -- cgit v1.2.3 From eea8fc40e16664ddc8a9aec77206da704a35dde0 Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 5 Jan 2023 07:24:22 -0800 Subject: Add option to save ti settings to file. --- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 30 +++++++++++++++++++++++--- 2 files changed, 28 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index e0f44c6d..933cd738 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -362,6 +362,7 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), + "save_train_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file when training starts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 71e07bcc..2bed2ecb 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,6 +1,7 @@ import os import sys import traceback +import inspect import torch import tqdm @@ -229,6 +230,28 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + checkpoint = sd_models.select_checkpoint() + model_name = checkpoint.model_name + model_hash = '[{}]'.format(checkpoint.hash) + + # Get a list of the argument names. + arg_names = inspect.getfullargspec(save_settings_to_file).args + + # Create a list of the argument names to include in the settings string. + names = arg_names[:16] # Include all arguments up until the preview-related ones. + if preview_from_txt2img: + names.extend(arg_names[16:]) # Include all remaining arguments if `preview_from_txt2img` is True. + + # Build the settings string. + settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" + for name in names: + value = locals()[name] + settings_str += f"{name}: {value}\n" + + with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: + fout.write(settings_str + "\n\n") + def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" @@ -292,13 +315,13 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ if initial_step >= steps: shared.state.textinfo = "Model has already been trained beyond specified max steps" return embedding, filename + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) - clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ None if clip_grad: - clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed @@ -306,7 +329,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - + if shared.opts.save_train_settings_to_txt: + save_settings_to_file(initial_step , len(ds) , embedding_name, len(embedding.vec) , learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) -- cgit v1.2.3 From 19a81ac2871ec900fc8b7955bbc2554b6c5ac6b1 Mon Sep 17 00:00:00 2001 From: cat Date: Thu, 5 Jan 2023 20:17:39 +0500 Subject: hires-fix: add "nearest-exact" latent upscale mode. --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index e0f44c6d..b7a3ce5c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -576,6 +576,7 @@ latent_upscale_modes = { "Latent (bicubic)": {"mode": "bicubic", "antialias": False}, "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True}, "Latent (nearest)": {"mode": "nearest", "antialias": False}, + "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False}, } sd_upscalers = [] -- cgit v1.2.3 From b85c2b5cf4a6809bc871718cf4680d49c3e95e94 Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 5 Jan 2023 08:14:38 -0800 Subject: Clean up ti, add same behavior to hypernetwork. --- modules/hypernetworks/hypernetwork.py | 31 +++++++++++++++++++++++++- modules/shared.py | 2 +- modules/textual_inversion/textual_inversion.py | 14 +++++++----- 3 files changed, 40 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 6a9b1398..d5985263 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -401,7 +401,33 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, hypernet.save(fn) shared.reload_hypernetworks() +# Note: textual_inversion.py has a nearly identical function of the same name. +def save_settings_to_file(initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + checkpoint = sd_models.select_checkpoint() + model_name = checkpoint.model_name + model_hash = '[{}]'.format(checkpoint.hash) + # Starting index of preview-related arguments. + border_index = 19 + + # Get a list of the argument names, excluding default argument. + sig = inspect.signature(save_settings_to_file) + arg_names = [p.name for p in sig.parameters.values() if p.default == p.empty] + + # Create a list of the argument names to include in the settings string. + names = arg_names[:border_index] # Include all arguments up until the preview-related ones. + + # Include preview-related arguments if applicable. + if preview_from_txt2img: + names.extend(arg_names[border_index:]) + + # Build the settings string. + settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" + for name in names: + value = locals()[name] + settings_str += f"{name}: {value}\n" + with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: + fout.write(settings_str + "\n\n") def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. @@ -457,7 +483,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - + + if shared.opts.save_training_settings_to_txt: + save_settings_to_file(initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) diff --git a/modules/shared.py b/modules/shared.py index 933cd738..10231a75 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -362,7 +362,7 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), - "save_train_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file when training starts."), + "save_training_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file whenever training starts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 2bed2ecb..68648550 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -230,18 +230,20 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +# Note: hypernetwork.py has a nearly identical function of the same name. def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): checkpoint = sd_models.select_checkpoint() model_name = checkpoint.model_name model_hash = '[{}]'.format(checkpoint.hash) - + # Starting index of preview-related arguments. + border_index = 16 # Get a list of the argument names. arg_names = inspect.getfullargspec(save_settings_to_file).args # Create a list of the argument names to include in the settings string. - names = arg_names[:16] # Include all arguments up until the preview-related ones. + names = arg_names[:border_index] # Include all arguments up until the preview-related ones. if preview_from_txt2img: - names.extend(arg_names[16:]) # Include all remaining arguments if `preview_from_txt2img` is True. + names.extend(arg_names[border_index:]) # Include all remaining arguments if `preview_from_txt2img` is True. # Build the settings string. settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" @@ -329,8 +331,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - if shared.opts.save_train_settings_to_txt: - save_settings_to_file(initial_step , len(ds) , embedding_name, len(embedding.vec) , learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + + if shared.opts.save_training_settings_to_txt: + save_settings_to_file(initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) -- cgit v1.2.3 From b6bab2f052b32c0ffebe6aecc1819ccf20cf8c5d Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 5 Jan 2023 09:14:56 -0800 Subject: Include model in log file. Exclude directory. --- modules/hypernetworks/hypernetwork.py | 28 +++++++++----------------- modules/textual_inversion/textual_inversion.py | 22 +++++++++----------- 2 files changed, 19 insertions(+), 31 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index d5985263..3237c37a 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -402,30 +402,22 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() # Note: textual_inversion.py has a nearly identical function of the same name. -def save_settings_to_file(initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - checkpoint = sd_models.select_checkpoint() - model_name = checkpoint.model_name - model_hash = '[{}]'.format(checkpoint.hash) +def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # Starting index of preview-related arguments. - border_index = 19 - - # Get a list of the argument names, excluding default argument. - sig = inspect.signature(save_settings_to_file) - arg_names = [p.name for p in sig.parameters.values() if p.default == p.empty] - + border_index = 21 + # Get a list of the argument names. + arg_names = inspect.getfullargspec(save_settings_to_file).args # Create a list of the argument names to include in the settings string. names = arg_names[:border_index] # Include all arguments up until the preview-related ones. - - # Include preview-related arguments if applicable. if preview_from_txt2img: - names.extend(arg_names[border_index:]) - + names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. # Build the settings string. settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" for name in names: - value = locals()[name] - settings_str += f"{name}: {value}\n" - + if name != 'log_directory': # It's useless and redundant to save log_directory. + value = locals()[name] + settings_str += f"{name}: {value}\n" + # Create or append to the file. with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: fout.write(settings_str + "\n\n") @@ -485,7 +477,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) latent_sampling_method = ds.latent_sampling_method diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 68648550..ce7e4f5d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -231,26 +231,22 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) # Note: hypernetwork.py has a nearly identical function of the same name. -def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - checkpoint = sd_models.select_checkpoint() - model_name = checkpoint.model_name - model_hash = '[{}]'.format(checkpoint.hash) +def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # Starting index of preview-related arguments. - border_index = 16 + border_index = 18 # Get a list of the argument names. - arg_names = inspect.getfullargspec(save_settings_to_file).args - + arg_names = inspect.getfullargspec(save_settings_to_file).args # Create a list of the argument names to include in the settings string. names = arg_names[:border_index] # Include all arguments up until the preview-related ones. if preview_from_txt2img: - names.extend(arg_names[border_index:]) # Include all remaining arguments if `preview_from_txt2img` is True. - + names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. # Build the settings string. settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" for name in names: - value = locals()[name] - settings_str += f"{name}: {value}\n" - + if name != 'log_directory': # It's useless and redundant to save log_directory. + value = locals()[name] + settings_str += f"{name}: {value}\n" + # Create or append to the file. with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: fout.write(settings_str + "\n\n") @@ -333,7 +329,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) latent_sampling_method = ds.latent_sampling_method -- cgit v1.2.3 From fda04e620d529031e2134520e74756d0efa30464 Mon Sep 17 00:00:00 2001 From: Kuma <36082288+KumiIT@users.noreply.github.com> Date: Thu, 5 Jan 2023 18:44:19 +0100 Subject: typo in TI --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 71e07bcc..24b43045 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -298,7 +298,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ None if clip_grad: - clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed -- cgit v1.2.3 From 847f869c67c7108e3e792fc193331d0e6acca29c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 5 Jan 2023 21:00:52 +0300 Subject: experimental optimization --- modules/processing.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 61e97077..a408d622 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -544,6 +544,29 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] + cached_uc = [None, None] + cached_c = [None, None] + + def get_conds_with_caching(function, required_prompts, steps, cache): + """ + Returns the result of calling function(shared.sd_model, required_prompts, steps) + using a cache to store the result if the same arguments have been used before. + + cache is an array containing two elements. The first element is a tuple + representing the previously used arguments, or None if no arguments + have been used before. The second element is where the previously + computed result is stored. + """ + + if cache[0] is not None and (required_prompts, steps) == cache[0]: + return cache[1] + + with devices.autocast(): + cache[1] = function(shared.sd_model, required_prompts, steps) + + cache[0] = (required_prompts, steps) + return cache[1] + with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) @@ -571,9 +594,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) - with devices.autocast(): - uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps) - c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) + uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc) + c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: -- cgit v1.2.3 From 81133d4168ae0bae9bf8bf1a1d4983319a589112 Mon Sep 17 00:00:00 2001 From: Faber Date: Fri, 6 Jan 2023 03:38:37 +0700 Subject: allow loading embeddings from subdirectories --- modules/textual_inversion/textual_inversion.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 24b43045..0a059044 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -149,19 +149,20 @@ class EmbeddingDatabase: else: self.skipped_embeddings[name] = embedding - for fn in os.listdir(self.embeddings_dir): - try: - fullfn = os.path.join(self.embeddings_dir, fn) - - if os.stat(fullfn).st_size == 0: + for root, dirs, fns in os.walk(self.embeddings_dir): + for fn in fns: + try: + fullfn = os.path.join(root, fn) + + if os.stat(fullfn).st_size == 0: + continue + + process_file(fullfn, fn) + except Exception: + print(f"Error loading embedding {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) continue - process_file(fullfn, fn) - except Exception: - print(f"Error loading embedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - continue - print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") if len(self.skipped_embeddings) > 0: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") -- cgit v1.2.3 From b5253f0dab529707f1fe2e11211a10ce2f264617 Mon Sep 17 00:00:00 2001 From: noodleanon <122053346+noodleanon@users.noreply.github.com> Date: Thu, 5 Jan 2023 21:21:48 +0000 Subject: allow img2img api to run scripts --- modules/api/api.py | 27 ++++++++++++++++++++++++--- modules/api/models.py | 2 +- modules/processing.py | 4 ++-- 3 files changed, 27 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 2103709b..aa62a42e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -11,7 +11,7 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.extras import run_extras @@ -28,8 +28,13 @@ def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) except: - raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") + raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}") +def script_name_to_index(name, scripts): + try: + return [script.title().lower() for script in scripts].index(name.lower()) + except: + raise HTTPException(status_code=422, detail=f"Script '{name}' not found") def validate_sampler_name(name): config = sd_samplers.all_samplers_map.get(name, None) @@ -170,6 +175,14 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") + if img2imgreq.script_name is not None: + if scripts.scripts_img2img.scripts == []: + scripts.scripts_img2img.initialize_scripts(True) + ui.create_ui() + + script_idx = script_name_to_index(img2imgreq.script_name, scripts.scripts_img2img.selectable_scripts) + script = scripts.scripts_img2img.selectable_scripts[script_idx] + mask = img2imgreq.mask if mask: mask = decode_base64_to_image(mask) @@ -186,13 +199,21 @@ class Api: args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. + args.pop('script_name', None) with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) p.init_images = [decode_base64_to_image(x) for x in init_images] shared.state.begin() - processed = process_images(p) + if 'script' in locals(): + p.outpath_grids = opts.outdir_img2img_grids + p.outpath_samples = opts.outdir_img2img_samples + p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args + processed = scripts.scripts_img2img.run(p, *p.script_args) + else: + processed = process_images(p) + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) diff --git a/modules/api/models.py b/modules/api/models.py index d8198a27..862477e7 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -106,7 +106,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] ).generate_model() class TextToImageResponse(BaseModel): diff --git a/modules/processing.py b/modules/processing.py index a408d622..d5ac7eb1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -98,7 +98,7 @@ class StableDiffusionProcessing(): """ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing """ - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) @@ -149,7 +149,7 @@ class StableDiffusionProcessing(): self.seed_resize_from_w = 0 self.scripts = None - self.script_args = None + self.script_args = script_args self.all_prompts = None self.all_negative_prompts = None self.all_seeds = None -- cgit v1.2.3 From 8111b5569d07c7ac3b695e28171aede728b4ae56 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 3 Jan 2023 20:43:05 -0500 Subject: Add support for PyTorch nightly and local builds --- modules/devices.py | 28 +++++++++++++++++++++++----- webui.py | 7 ++++++- 2 files changed, 29 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 800510b7..caeb0276 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs): return orig_tensor_numpy(self, *args, **kwargs) -# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working -if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): - torch.Tensor.to = tensor_to_fix - torch.nn.functional.layer_norm = layer_norm_fix - torch.Tensor.numpy = numpy_fix +# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 +orig_cumsum = torch.cumsum +orig_Tensor_cumsum = torch.Tensor.cumsum +def cumsum_fix(input, cumsum_func, *args, **kwargs): + if input.device.type == 'mps': + output_dtype = kwargs.get('dtype', input.dtype) + if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]): + return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) + return cumsum_func(input, *args, **kwargs) + + +if has_mps(): + if version.parse(torch.__version__) < version.parse("1.13"): + # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working + torch.Tensor.to = tensor_to_fix + torch.nn.functional.layer_norm = layer_norm_fix + torch.Tensor.numpy = numpy_fix + elif version.parse(torch.__version__) > version.parse("1.13.1"): + if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)): + torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) + torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) + orig_narrow = torch.narrow + torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) diff --git a/webui.py b/webui.py index 13375e71..ddfaea95 100644 --- a/webui.py +++ b/webui.py @@ -4,7 +4,7 @@ import threading import time import importlib import signal -import threading +import re from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware @@ -13,6 +13,11 @@ from modules import import_hook, errors from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path +import torch +# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors +if ".dev" in torch.__version__ or "+git" in torch.__version__: + torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0) + from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer import modules.extras -- cgit v1.2.3 From d61a5aa4f623f6630670241aca8fc5c2a6381769 Mon Sep 17 00:00:00 2001 From: acncagua Date: Fri, 6 Jan 2023 10:58:22 +0900 Subject: Add files via upload --- modules/ui.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 81d96c5b..030f0685 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -550,6 +550,8 @@ Requested path was: {f} os.startfile(path) elif platform.system() == "Darwin": sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) else: sp.Popen(["xdg-open", path]) -- cgit v1.2.3 From d782a95967c9eea753df3333cd1954b6ec73eba0 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 27 Dec 2022 08:50:55 -0500 Subject: Add Birch-san's sub-quadratic attention implementation --- README.md | 1 + modules/sd_hijack.py | 15 ++- modules/sd_hijack_optimizations.py | 124 ++++++++++++++++++----- modules/shared.py | 4 + modules/sub_quadratic_attention.py | 201 +++++++++++++++++++++++++++++++++++++ requirements.txt | 2 +- 6 files changed, 312 insertions(+), 35 deletions(-) create mode 100644 modules/sub_quadratic_attention.py (limited to 'modules') diff --git a/README.md b/README.md index 556000fb..1913caf3 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ The documentation was moved from this README over to the project's [wiki](https: - Ideas for optimizations - https://github.com/basujindal/stable-diffusion - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) +- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san), Amin Rezaei (https://github.com/AminRezaei0x443) - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas). - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 690a9ec2..019a6f3f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet -from modules.sd_hijack_optimizations import invokeAI_mps_available - import ldm.modules.attention import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.openaimodel @@ -40,17 +38,16 @@ def apply_optimizations(): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + elif cmd_opts.opt_sub_quad_attention: + print("Applying sub-quadratic cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): - if not invokeAI_mps_available and shared.device.type == 'mps': - print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") - print("Applying v1 cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - else: - print("Applying cross attention optimization (InvokeAI).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI + print("Applying cross attention optimization (InvokeAI).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): print("Applying cross attention optimization (Doggettx).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 02c87f40..f5c153e8 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,7 @@ import math import sys import traceback -import importlib +import psutil import torch from torch import einsum @@ -12,6 +12,8 @@ from einops import rearrange from modules import shared from modules.hypernetworks import hypernetwork +from .sub_quadratic_attention import efficient_dot_product_attention + if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: @@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: print(traceback.format_exc(), file=sys.stderr) +def get_available_vram(): + if shared.device.type == 'cuda': + stats = torch.cuda.memory_stats(shared.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + return mem_free_total + else: + return psutil.virtual_memory().available + + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): h = self.heads @@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = get_available_vram() gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() @@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def check_for_psutil(): - try: - spec = importlib.util.find_spec('psutil') - return spec is not None - except ModuleNotFoundError: - return False - -invokeAI_mps_available = check_for_psutil() - # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- -if invokeAI_mps_available: - import psutil - mem_total_gb = psutil.virtual_memory().total // (1 << 30) +mem_total_gb = psutil.virtual_memory().total // (1 << 30) def einsum_op_compvis(q, k, v): s = einsum('b i d, b j d -> b i j', q, k) @@ -215,6 +214,70 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # -- End of code from https://github.com/invoke-ai/InvokeAI -- + +# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +def sub_quad_attention_forward(self, x, context=None, mask=None): + assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." + + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, context_k, context_v, x + + q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + + x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) + + out_proj, dropout = self.to_out + x = out_proj(x) + x = dropout(x) + + return x + +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True): + bytes_per_token = torch.finfo(q.dtype).bits//8 + batch_x_heads, q_tokens, _ = q.shape + _, k_tokens, _ = k.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + + available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + + if chunk_threshold_bytes is None: + chunk_threshold_bytes = available_vram + elif chunk_threshold_bytes == 0: + chunk_threshold_bytes = None + + if kv_chunk_size_min is None: + kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) + elif kv_chunk_size_min == 0: + kv_chunk_size_min = None + + if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. send it down the unchunked fast-path + query_chunk_size = q_tokens + kv_chunk_size = k_tokens + + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) + + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) @@ -252,12 +315,7 @@ def cross_attention_attnblock_forward(self, x): h_ = torch.zeros_like(k, device=q.device) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = get_available_vram() tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() mem_required = tensor_size * 2.5 @@ -312,3 +370,19 @@ def xformers_attnblock_forward(self, x): return x + out except NotImplementedError: return cross_attention_attnblock_forward(self, x) + +def sub_quad_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out diff --git a/modules/shared.py b/modules/shared.py index d4ddeea0..487a7792 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -56,6 +56,10 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") +parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") +parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024) +parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None) +parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None) parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py new file mode 100644 index 00000000..b11dc1c7 --- /dev/null +++ b/modules/sub_quadratic_attention.py @@ -0,0 +1,201 @@ +# original source: +# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py +# license: +# unspecified +# credit: +# Amin Rezaei (original author) +# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) +# implementation of: +# Self-attention Does Not Need O(n2) Memory": +# https://arxiv.org/abs/2112.05682v2 + +from functools import partial +import torch +from torch import Tensor +from torch.utils.checkpoint import checkpoint +import math +from typing import Optional, NamedTuple, Protocol, List + +def dynamic_slice( + x: Tensor, + starts: List[int], + sizes: List[int], +) -> Tensor: + slicing = [slice(start, start + size) for start, size in zip(starts, sizes)] + return x[slicing] + +class AttnChunk(NamedTuple): + exp_values: Tensor + exp_weights_sum: Tensor + max_score: Tensor + +class SummarizeChunk(Protocol): + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> AttnChunk: ... + +class ComputeQueryChunkAttn(Protocol): + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> Tensor: ... + +def _summarize_chunk( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> AttnChunk: + attn_weights = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + max_score, _ = torch.max(attn_weights, -1, keepdim=True) + max_score = max_score.detach() + exp_weights = torch.exp(attn_weights - max_score) + exp_values = torch.bmm(exp_weights, value) + max_score = max_score.squeeze(-1) + return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + +def _query_chunk_attention( + query: Tensor, + key: Tensor, + value: Tensor, + summarize_chunk: SummarizeChunk, + kv_chunk_size: int, +) -> Tensor: + batch_x_heads, k_tokens, k_channels_per_head = key.shape + _, _, v_channels_per_head = value.shape + + def chunk_scanner(chunk_idx: int) -> AttnChunk: + key_chunk = dynamic_slice( + key, + (0, chunk_idx, 0), + (batch_x_heads, kv_chunk_size, k_channels_per_head) + ) + value_chunk = dynamic_slice( + value, + (0, chunk_idx, 0), + (batch_x_heads, kv_chunk_size, v_channels_per_head) + ) + return summarize_chunk(query, key_chunk, value_chunk) + + chunks: List[AttnChunk] = [ + chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) + ] + acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) + chunk_values, chunk_weights, chunk_max = acc_chunk + + global_max, _ = torch.max(chunk_max, 0, keepdim=True) + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= torch.unsqueeze(max_diffs, -1) + chunk_weights *= max_diffs + + all_values = chunk_values.sum(dim=0) + all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) + return all_values / all_weights + +# TODO: refactor CrossAttention#get_attention_scores to share code with this +def _get_attention_scores_no_kv_chunking( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> Tensor: + attn_scores = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + attn_probs = attn_scores.softmax(dim=-1) + del attn_scores + hidden_states_slice = torch.bmm(attn_probs, value) + return hidden_states_slice + +class ScannedChunk(NamedTuple): + chunk_idx: int + attn_chunk: AttnChunk + +def efficient_dot_product_attention( + query: Tensor, + key: Tensor, + value: Tensor, + query_chunk_size=1024, + kv_chunk_size: Optional[int] = None, + kv_chunk_size_min: Optional[int] = None, + use_checkpoint=True, +): + """Computes efficient dot-product attention given query, key, and value. + This is efficient version of attention presented in + https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. + Args: + query: queries for calculating attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + key: keys for calculating attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + value: values to be used in attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + query_chunk_size: int: query chunks size + kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens) + kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done). + use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference) + Returns: + Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. + """ + batch_x_heads, q_tokens, q_channels_per_head = query.shape + _, k_tokens, _ = key.shape + scale = q_channels_per_head ** -0.5 + + kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) + if kv_chunk_size_min is not None: + kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) + + def get_query_chunk(chunk_idx: int) -> Tensor: + return dynamic_slice( + query, + (0, chunk_idx, 0), + (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) + ) + + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) + summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk + compute_query_chunk_attn: ComputeQueryChunkAttn = partial( + _get_attention_scores_no_kv_chunking, + scale=scale + ) if k_tokens <= kv_chunk_size else ( + # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw) + partial( + _query_chunk_attention, + kv_chunk_size=kv_chunk_size, + summarize_chunk=summarize_chunk, + ) + ) + + if q_tokens <= query_chunk_size: + # fast-path for when there's just 1 query chunk + return compute_query_chunk_attn( + query=query, + key=key, + value=value, + ) + + # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, + # and pass slices to be mutated, instead of torch.cat()ing the returned slices + res = torch.cat([ + compute_query_chunk_attn( + query=get_query_chunk(i * query_chunk_size), + key=key, + value=value, + ) for i in range(math.ceil(q_tokens / query_chunk_size)) + ], dim=1) + return res diff --git a/requirements.txt b/requirements.txt index 5bed694e..0dbea322 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,4 +30,4 @@ inflection GitPython torchsde safetensors -psutil; sys_platform == 'darwin' +psutil -- cgit v1.2.3 From b119815333026164f2bd7d1ca71f3e4f7a9afd0d Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 5 Jan 2023 04:37:17 -0500 Subject: Use narrow instead of dynamic_slice --- modules/sub_quadratic_attention.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index b11dc1c7..95924d24 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -5,6 +5,7 @@ # credit: # Amin Rezaei (original author) # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) +# brkirch (modified to use torch.narrow instead of dynamic_slice implementation) # implementation of: # Self-attention Does Not Need O(n2) Memory": # https://arxiv.org/abs/2112.05682v2 @@ -16,13 +17,13 @@ from torch.utils.checkpoint import checkpoint import math from typing import Optional, NamedTuple, Protocol, List -def dynamic_slice( - x: Tensor, - starts: List[int], - sizes: List[int], +def narrow_trunc( + input: Tensor, + dim: int, + start: int, + length: int ) -> Tensor: - slicing = [slice(start, start + size) for start, size in zip(starts, sizes)] - return x[slicing] + return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) class AttnChunk(NamedTuple): exp_values: Tensor @@ -76,15 +77,17 @@ def _query_chunk_attention( _, _, v_channels_per_head = value.shape def chunk_scanner(chunk_idx: int) -> AttnChunk: - key_chunk = dynamic_slice( + key_chunk = narrow_trunc( key, - (0, chunk_idx, 0), - (batch_x_heads, kv_chunk_size, k_channels_per_head) + 1, + chunk_idx, + kv_chunk_size ) - value_chunk = dynamic_slice( + value_chunk = narrow_trunc( value, - (0, chunk_idx, 0), - (batch_x_heads, kv_chunk_size, v_channels_per_head) + 1, + chunk_idx, + kv_chunk_size ) return summarize_chunk(query, key_chunk, value_chunk) @@ -161,10 +164,11 @@ def efficient_dot_product_attention( kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) def get_query_chunk(chunk_idx: int) -> Tensor: - return dynamic_slice( + return narrow_trunc( query, - (0, chunk_idx, 0), - (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) + 1, + chunk_idx, + min(query_chunk_size, q_tokens) ) summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) -- cgit v1.2.3 From 683287d87f6401083a8d63eedc00ca7410214ca1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 08:52:06 +0300 Subject: rework saving training params to file #6372 --- modules/hypernetworks/hypernetwork.py | 28 +++++++------------------- modules/shared.py | 2 +- modules/textual_inversion/logging.py | 24 ++++++++++++++++++++++ modules/textual_inversion/textual_inversion.py | 23 +++------------------ 4 files changed, 35 insertions(+), 42 deletions(-) create mode 100644 modules/textual_inversion/logging.py (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3237c37a..b0cfbe71 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -13,7 +13,7 @@ import tqdm from einops import rearrange, repeat from ldm.util import default from modules import devices, processing, sd_models, shared, sd_samplers -from modules.textual_inversion import textual_inversion +from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_ @@ -401,25 +401,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, hypernet.save(fn) shared.reload_hypernetworks() -# Note: textual_inversion.py has a nearly identical function of the same name. -def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # Starting index of preview-related arguments. - border_index = 21 - # Get a list of the argument names. - arg_names = inspect.getfullargspec(save_settings_to_file).args - # Create a list of the argument names to include in the settings string. - names = arg_names[:border_index] # Include all arguments up until the preview-related ones. - if preview_from_txt2img: - names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. - # Build the settings string. - settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" - for name in names: - if name != 'log_directory': # It's useless and redundant to save log_directory. - value = locals()[name] - settings_str += f"{name}: {value}\n" - # Create or append to the file. - with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: - fout.write(settings_str + "\n\n") + def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. @@ -477,7 +459,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + saved_params = dict( + model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), + **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]} + ) + logging.save_settings_to_file(log_directory, {**saved_params, **locals()}) latent_sampling_method = ds.latent_sampling_method diff --git a/modules/shared.py b/modules/shared.py index f0e10b35..57e489d0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -362,7 +362,7 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), - "save_training_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file whenever training starts."), + "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py new file mode 100644 index 00000000..8b1981d5 --- /dev/null +++ b/modules/textual_inversion/logging.py @@ -0,0 +1,24 @@ +import datetime +import json +import os + +saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"} +saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} +saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} +saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet +saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"} + + +def save_settings_to_file(log_directory, all_params): + now = datetime.datetime.now() + params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")} + + keys = saved_params_all + if all_params.get('preview_from_txt2img'): + keys = keys | saved_params_previews + + params.update({k: v for k, v in all_params.items() if k in keys}) + + filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json' + with open(os.path.join(log_directory, filename), "w") as file: + json.dump(params, file, indent=4) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e9cf432f..f9f5e8cd 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay) +from modules.textual_inversion.logging import save_settings_to_file + class Embedding: def __init__(self, vec, name, step=None): @@ -231,25 +233,6 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) -# Note: hypernetwork.py has a nearly identical function of the same name. -def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # Starting index of preview-related arguments. - border_index = 18 - # Get a list of the argument names. - arg_names = inspect.getfullargspec(save_settings_to_file).args - # Create a list of the argument names to include in the settings string. - names = arg_names[:border_index] # Include all arguments up until the preview-related ones. - if preview_from_txt2img: - names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. - # Build the settings string. - settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" - for name in names: - if name != 'log_directory': # It's useless and redundant to save log_directory. - value = locals()[name] - settings_str += f"{name}: {value}\n" - # Create or append to the file. - with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: - fout.write(settings_str + "\n\n") def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" @@ -330,7 +313,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) latent_sampling_method = ds.latent_sampling_method -- cgit v1.2.3 From b95a4c0ce5ab9c414e0494193bfff665f45e9e65 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 01:01:51 -0500 Subject: Change sub-quad chunk threshold to use percentage --- modules/sd_hijack_optimizations.py | 18 +++++++++--------- modules/shared.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index f5c153e8..b416e9ac 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -233,7 +233,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) - x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) @@ -243,20 +243,20 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): return x -def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True): +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): bytes_per_token = torch.finfo(q.dtype).bits//8 batch_x_heads, q_tokens, _ = q.shape _, k_tokens, _ = k.shape qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens - available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) - - if chunk_threshold_bytes is None: - chunk_threshold_bytes = available_vram - elif chunk_threshold_bytes == 0: + if chunk_threshold is None: + chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + elif chunk_threshold == 0: chunk_threshold_bytes = None + else: + chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) - if kv_chunk_size_min is None: + if kv_chunk_size_min is None and chunk_threshold_bytes is not None: kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) elif kv_chunk_size_min == 0: kv_chunk_size_min = None @@ -382,7 +382,7 @@ def sub_quad_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index cb1dc312..d7a81db1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -59,7 +59,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024) parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None) -parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None) +parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None) parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") -- cgit v1.2.3 From 5deb2a19ccea57a50252e8fcb07b4d17c6599def Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 01:33:15 -0500 Subject: Allow Doggettx's cross attention opt without CUDA --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ef25dadb..bd101e5b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -50,7 +50,7 @@ def apply_optimizations(): print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 optimization_method = 'V1' - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()): print("Applying cross attention optimization (InvokeAI).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI optimization_method = 'InvokeAI' -- cgit v1.2.3 From c9bded39ee05bd0507ccd27d2b674d86d6c0c8e8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 12:32:44 +0300 Subject: sort extensions by date and add an option to sort by other columns --- modules/ui_extensions.py | 44 ++++++++++++++++++++++++++++++++------------ style.css | 11 ++++++++++- 2 files changed, 42 insertions(+), 13 deletions(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index eec9586f..742e745e 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -162,15 +162,15 @@ def install_extension_from_url(dirname, url): shutil.rmtree(tmpdir, True) -def install_extension_from_index(url, hide_tags): +def install_extension_from_index(url, hide_tags, sort_column): ext_table, message = install_extension_from_url(None, url) - code, _ = refresh_available_extensions_from_data(hide_tags) + code, _ = refresh_available_extensions_from_data(hide_tags, sort_column) return code, ext_table, message -def refresh_available_extensions(url, hide_tags): +def refresh_available_extensions(url, hide_tags, sort_column): global available_extensions import urllib.request @@ -179,18 +179,28 @@ def refresh_available_extensions(url, hide_tags): available_extensions = json.loads(text) - code, tags = refresh_available_extensions_from_data(hide_tags) + code, tags = refresh_available_extensions_from_data(hide_tags, sort_column) return url, code, gr.CheckboxGroup.update(choices=tags), '' -def refresh_available_extensions_for_tags(hide_tags): - code, _ = refresh_available_extensions_from_data(hide_tags) +def refresh_available_extensions_for_tags(hide_tags, sort_column): + code, _ = refresh_available_extensions_from_data(hide_tags, sort_column) return code, '' -def refresh_available_extensions_from_data(hide_tags): +sort_ordering = [ + # (reverse, order_by_function) + (True, lambda x: x.get('added', 'z')), + (False, lambda x: x.get('added', 'z')), + (False, lambda x: x.get('name', 'z')), + (True, lambda x: x.get('name', 'z')), + (False, lambda x: 'z'), +] + + +def refresh_available_extensions_from_data(hide_tags, sort_column): extlist = available_extensions["extensions"] installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} @@ -210,8 +220,11 @@ def refresh_available_extensions_from_data(hide_tags): """ - for ext in extlist: + sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0] + + for ext in sorted(extlist, key=sort_function, reverse=sort_reverse): name = ext.get("name", "noname") + added = ext.get('added', 'unknown') url = ext.get("url", None) description = ext.get("description", "") extension_tags = ext.get("tags", []) @@ -233,7 +246,7 @@ def refresh_available_extensions_from_data(hide_tags): code += f""" {html.escape(name)}
{tags_text} - {html.escape(description)} + {html.escape(description)}

Added: {html.escape(added)}

{install_code} @@ -291,25 +304,32 @@ def create_ui(): with gr.Row(): hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) + sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index") install_result = gr.HTML() available_extensions_table = gr.HTML() refresh_available_extensions_button.click( fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), - inputs=[available_extensions_index, hide_tags], + inputs=[available_extensions_index, hide_tags, sort_column], outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result], ) install_extension_button.click( fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), - inputs=[extension_to_install, hide_tags], + inputs=[extension_to_install, hide_tags, sort_column], outputs=[available_extensions_table, extensions_table, install_result], ) hide_tags.change( fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), - inputs=[hide_tags], + inputs=[hide_tags, sort_column], + outputs=[available_extensions_table, install_result] + ) + + sort_column.change( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[hide_tags, sort_column], outputs=[available_extensions_table, install_result] ) diff --git a/style.css b/style.css index ee74d79e..f1b23b53 100644 --- a/style.css +++ b/style.css @@ -555,7 +555,7 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h /* Extensions */ -#tab_extensions table{ +#tab_extensions table``{ border-collapse: collapse; } @@ -581,6 +581,15 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h font-size: 95%; } +#available_extensions .info{ + margin: 0; +} + +#available_extensions .date_added{ + opacity: 0.85; + font-size: 90%; +} + #image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{ min-width: auto; padding-left: 0.5em; -- cgit v1.2.3 From 65ed4421e609dda3112f236c13e4db14caa71364 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 13:55:50 +0300 Subject: add callback for when the script is unloaded --- modules/script_callbacks.py | 18 +++++++++++++++++- webui.py | 2 ++ 2 files changed, 19 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index de69fd9f..608c5300 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -71,6 +71,7 @@ callback_map = dict( callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], + callbacks_script_unloaded=[], ) @@ -171,6 +172,14 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid') +def script_unloaded_callback(): + for c in reversed(callback_map['callbacks_script_unloaded']): + try: + c.callback() + except Exception: + report_exception(c, 'script_unloaded') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -202,7 +211,7 @@ def on_app_started(callback): def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is - passed as an argument""" + passed as an argument; this function is also called when the script is reloaded. """ add_callback(callback_map['callbacks_model_loaded'], callback) @@ -279,3 +288,10 @@ def on_image_grid(callback): - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. """ add_callback(callback_map['callbacks_image_grid'], callback) + + +def on_script_unloaded(callback): + """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that + the script did should be reverted here""" + + add_callback(callback_map['callbacks_script_unloaded'], callback) diff --git a/webui.py b/webui.py index ff6eb6eb..733a06b5 100644 --- a/webui.py +++ b/webui.py @@ -187,12 +187,14 @@ def webui(): sd_samplers.set_samplers() + modules.script_callbacks.script_unloaded_callback() extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) modelloader.forbid_loaded_nonbuiltin_upscalers() modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) modelloader.load_upscalers() for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: -- cgit v1.2.3 From 3246a2d6b898da6a98fe9df4dc67944635a41bd3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 16:03:43 +0300 Subject: remove restriction for saving dropdowns to ui-config.json --- modules/scripts.py | 1 - modules/ui.py | 10 ++-------- 2 files changed, 2 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 0c44f191..35164093 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -290,7 +290,6 @@ class ScriptRunner: script.group = group dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") - dropdown.save_to_config = True inputs[0] = dropdown for script in self.selectable_scripts: diff --git a/modules/ui.py b/modules/ui.py index 030f0685..b79d24ee 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -435,11 +435,9 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=1, elem_id="style_pos_col"): prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - prompt_style.save_to_config = True with gr.Column(scale=1, elem_id="style_neg_col"): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - prompt_style2.save_to_config = True return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button @@ -638,7 +636,6 @@ def create_sampler_and_steps_selection(choices, tabname): if opts.samplers_in_dropdown: with FormRow(elem_id=f"sampler_selection_{tabname}"): sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - sampler_index.save_to_config = True steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) else: with FormGroup(elem_id=f"sampler_selection_{tabname}"): @@ -1794,7 +1791,7 @@ def create_ui(): if init_field is not None: init_field(saved_value) - if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible: + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: apply_field(x, 'visible') if type(x) == gr.Slider: @@ -1815,11 +1812,8 @@ def create_ui(): if type(x) == gr.Number: apply_field(x, 'value') - # Since there are many dropdowns that shouldn't be saved, - # we only mark dropdowns that should be saved. - if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False): + if type(x) == gr.Dropdown: apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) - apply_field(x, 'visible') visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") -- cgit v1.2.3 From 50194de93ffc9db763d9b08fcc9c3bde1aa86151 Mon Sep 17 00:00:00 2001 From: Kuma <36082288+KumiIT@users.noreply.github.com> Date: Fri, 6 Jan 2023 16:12:45 +0100 Subject: typo UI fixes #6391 --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 57e489d0..865c3c07 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -430,7 +430,7 @@ options_templates.update(options_section(('ui', "User interface"), { "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), - 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"), + 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) -- cgit v1.2.3 From 3992ecbe6e46a465062508c677964534e7397f72 Mon Sep 17 00:00:00 2001 From: Mitchell Boot <47387831+Mitchell1711@users.noreply.github.com> Date: Fri, 6 Jan 2023 18:02:46 +0100 Subject: Added UI elements Added a new row to hires fix that shows the new resolution after scaling --- modules/ui.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index b79d24ee..20f7d2a2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -255,6 +255,12 @@ def add_style(name: str, prompt: str, negative_prompt: str): return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] +def calc_resolution_hires(x, y, scale): + #final res can only be a multiple of 8 + scaled_x = int(x * scale // 8) * 8 + scaled_y = int(y * scale // 8) * 8 + + return "

Upscaled Resolution: "+str(scaled_x)+"x"+str(scaled_y)+"

" def apply_styles(prompt, prompt_neg, style1_name, style2_name): prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) @@ -718,6 +724,12 @@ def create_ui(): hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + + with FormRow(elem_id="txt2img_hires_fix_row3"): + hr_final_resolution = gr.HTML(value=calc_resolution_hires(width.value, height.value, hr_scale.value), elem_id="txtimg_hr_finalres") + hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) elif category == "batch": if not opts.dimensions_and_batch_together: -- cgit v1.2.3 From 991368c8d54404d8e13d4c6e76a0f32644e65ad4 Mon Sep 17 00:00:00 2001 From: Mitchell Boot <47387831+Mitchell1711@users.noreply.github.com> Date: Fri, 6 Jan 2023 18:24:29 +0100 Subject: remove camelcase --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 20f7d2a2..6fc8b7d7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -260,7 +260,7 @@ def calc_resolution_hires(x, y, scale): scaled_x = int(x * scale // 8) * 8 scaled_y = int(y * scale // 8) * 8 - return "

Upscaled Resolution: "+str(scaled_x)+"x"+str(scaled_y)+"

" + return "

Upscaled resolution: "+str(scaled_x)+"x"+str(scaled_y)+"

" def apply_styles(prompt, prompt_neg, style1_name, style2_name): prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) -- cgit v1.2.3 From c18add68ef7d2de3617cbbaff864b0c74cfdf6c0 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 16:42:47 -0500 Subject: Added license --- html/licenses.html | 29 ++++++++++++++++++++++++++++- modules/sd_hijack_optimizations.py | 1 + modules/sub_quadratic_attention.py | 2 +- 3 files changed, 30 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/html/licenses.html b/html/licenses.html index 9eeaa072..570630eb 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -184,7 +184,7 @@ SOFTWARE.

SwinIR

-Code added by contirubtors, most likely copied from this repository. +Code added by contributors, most likely copied from this repository.
                                  Apache License
@@ -390,3 +390,30 @@ SOFTWARE.
    limitations under the License.
 
+

Memory Efficient Attention

+The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that. +
+MIT License
+
+Copyright (c) 2023 Alex Birch
+Copyright (c) 2023 Amin Rezaei
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+ diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index b416e9ac..cdc63ed7 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -216,6 +216,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface def sub_quad_attention_forward(self, x, context=None, mask=None): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 95924d24..fea7aaac 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -1,7 +1,7 @@ # original source: # https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py # license: -# unspecified +# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license) # credit: # Amin Rezaei (original author) # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) -- cgit v1.2.3 From 82c1f10b144f733460feead0bdc37a861489dc57 Mon Sep 17 00:00:00 2001 From: Dean Hopkins Date: Fri, 6 Jan 2023 22:00:12 +0000 Subject: increase upscale api validation limit --- modules/api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/api/models.py b/modules/api/models.py index f77951fc..22b88c59 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -125,7 +125,7 @@ class ExtrasBaseRequest(BaseModel): gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") - upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?") -- cgit v1.2.3 From 79e39fae6110c20a3ee6255e2841c877f65e8cbd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 01:45:28 +0300 Subject: CLIP hijack rework --- modules/sd_hijack.py | 6 +- modules/sd_hijack_clip.py | 348 ++++++++++++------------- modules/sd_hijack_clip_old.py | 81 ++++++ modules/textual_inversion/textual_inversion.py | 1 - modules/ui.py | 2 +- 5 files changed, 256 insertions(+), 182 deletions(-) create mode 100644 modules/sd_hijack_clip_old.py (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fa2cd4bb..71cc145a 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -150,10 +150,10 @@ class StableDiffusionModelHijack: def clear_comments(self): self.comments = [] - def tokenize(self, text): - _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) + def get_prompt_lengths(self, text): + _, token_count = self.clip.process_texts([text]) - return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count) + return token_count, self.clip.get_target_prompt_token_count(token_count) class EmbeddingsWithFixes(torch.nn.Module): diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index ca92b142..ac3020d7 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -1,12 +1,28 @@ import math +from collections import namedtuple import torch from modules import prompt_parser, devices from modules.shared import opts -def get_target_prompt_token_count(token_count): - return math.ceil(max(token_count, 1) / 75) * 75 + +class PromptChunk: + """ + This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt. + If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary. + Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token, + so just 75 tokens from prompt. + """ + + def __init__(self): + self.tokens = [] + self.multipliers = [] + self.fixes = [] + + +PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) +"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk""" class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): @@ -14,17 +30,49 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): super().__init__() self.wrapped = wrapped self.hijack = hijack + self.chunk_length = 75 + + def empty_chunk(self): + """creates an empty PromptChunk and returns it""" + + chunk = PromptChunk() + chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1) + chunk.multipliers = [1.0] * (self.chunk_length + 2) + return chunk + + def get_target_prompt_token_count(self, token_count): + """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented""" + + return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length def tokenize(self, texts): + """Converts a batch of texts into a batch of token ids""" + raise NotImplementedError def encode_with_transformers(self, tokens): + """ + converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens; + All python lists with tokens are assumed to have same length, usually 77. + if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on + model - can be 768 and 1024 + """ + raise NotImplementedError def encode_embedding_init_text(self, init_text, nvpt): + """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through + transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned.""" + raise NotImplementedError - def tokenize_line(self, line, used_custom_terms, hijack_comments): + def tokenize_line(self, line): + """ + this transforms a single prompt into a list of PromptChunk objects - as many as needed to + represent the prompt. + Returns the list and the total number of tokens in the prompt. + """ + if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) else: @@ -32,205 +80,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): tokenized = self.tokenize([text for text, _ in parsed]) - fixes = [] - remade_tokens = [] - multipliers = [] + chunks = [] + chunk = PromptChunk() + token_count = 0 last_comma = -1 - for tokens, (text, weight) in zip(tokenized, parsed): - i = 0 - while i < len(tokens): - token = tokens[i] + def next_chunk(): + """puts current chunk into the list of results and produces the next one - empty""" + nonlocal token_count + nonlocal last_comma + nonlocal chunk + + token_count += len(chunk.tokens) + to_add = self.chunk_length - len(chunk.tokens) + if to_add > 0: + chunk.tokens += [self.id_end] * to_add + chunk.multipliers += [1.0] * to_add - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end] + chunk.multipliers = [1.0] + chunk.multipliers + [1.0] + + last_comma = -1 + chunks.append(chunk) + chunk = PromptChunk() + + for tokens, (text, weight) in zip(tokenized, parsed): + position = 0 + while position < len(tokens): + token = tokens[position] if token == self.comma_token: - last_comma = len(remade_tokens) - elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: - last_comma += 1 - reloc_tokens = remade_tokens[last_comma:] - reloc_mults = multipliers[last_comma:] + last_comma = len(chunk.tokens) + + # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack + # is a setting that specifies that is there is a comma nearby, the text after comma should be moved out of this chunk and into the next. + elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: + break_location = last_comma + 1 + + reloc_tokens = chunk.tokens[break_location:] + reloc_mults = chunk.multipliers[break_location:] - remade_tokens = remade_tokens[:last_comma] - length = len(remade_tokens) + chunk.tokens = chunk.tokens[:break_location] + chunk.multipliers = chunk.multipliers[:break_location] - rem = int(math.ceil(length / 75)) * 75 - length - remade_tokens += [self.id_end] * rem + reloc_tokens - multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults + next_chunk() + chunk.tokens = reloc_tokens + chunk.multipliers = reloc_mults + if len(chunk.tokens) == self.chunk_length: + next_chunk() + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position) if embedding is None: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - iteration = len(remade_tokens) // 75 - if (len(remade_tokens) + emb_len) // 75 != iteration: - rem = (75 * (iteration + 1) - len(remade_tokens)) - remade_tokens += [self.id_end] * rem - multipliers += [1.0] * rem - iteration += 1 - fixes.append((iteration, (len(remade_tokens) % 75, embedding))) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - token_count = len(remade_tokens) - prompt_target_length = get_target_prompt_token_count(token_count) - tokens_to_add = prompt_target_length - len(remade_tokens) - - remade_tokens = remade_tokens + [self.id_end] * tokens_to_add - multipliers = multipliers + [1.0] * tokens_to_add - - return remade_tokens, fixes, multipliers, token_count - - def process_text(self, texts): - used_custom_terms = [] - remade_batch_tokens = [] - hijack_comments = [] - hijack_fixes = [] + chunk.tokens.append(token) + chunk.multipliers.append(weight) + position += 1 + continue + + emb_len = int(embedding.vec.shape[0]) + if len(chunk.tokens) + emb_len > self.chunk_length: + next_chunk() + + chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding)) + + chunk.tokens += [0] * emb_len + chunk.multipliers += [weight] * emb_len + position += embedding_length_in_tokens + + if len(chunk.tokens) > 0: + next_chunk() + + return chunks, token_count + + def process_texts(self, texts): + """ + Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum + length, in tokens, of all texts. + """ + token_count = 0 cache = {} - batch_multipliers = [] + batch_chunks = [] for line in texts: if line in cache: - remade_tokens, fixes, multipliers = cache[line] + chunks = cache[line] else: - remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + chunks, current_token_count = self.tokenize_line(line) token_count = max(current_token_count, token_count) - cache[line] = (remade_tokens, fixes, multipliers) + cache[line] = chunks - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) + batch_chunks.append(chunks) - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + return batch_chunks, token_count - def process_text_old(self, texts): - id_start = self.id_start - id_end = self.id_end - maxlen = self.wrapped.max_length # you get to stay at 77 - used_custom_terms = [] - remade_batch_tokens = [] - hijack_comments = [] - hijack_fixes = [] - token_count = 0 + def forward(self, texts): + """ + Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. + Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will + be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024. + An example shape returned by this function can be: (2, 77, 768). + Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet + is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" + """ - cache = {} - batch_tokens = self.tokenize(texts) - batch_multipliers = [] - for tokens in batch_tokens: - tuple_tokens = tuple(tokens) + if opts.use_old_emphasis_implementation: + import modules.sd_hijack_clip_old + return modules.sd_hijack_clip_old.forward_old(self, texts) - if tuple_tokens in cache: - remade_tokens, fixes, multipliers = cache[tuple_tokens] - else: - fixes = [] - remade_tokens = [] - multipliers = [] - mult = 1.0 - - i = 0 - while i < len(tokens): - token = tokens[i] - - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - - mult_change = self.token_mults.get(token) if opts.enable_emphasis else None - if mult_change is not None: - mult *= mult_change - i += 1 - elif embedding is None: - remade_tokens.append(token) - multipliers.append(mult) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - fixes.append((len(remade_tokens), embedding)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - if len(remade_tokens) > maxlen - 2: - vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} - ovf = remade_tokens[maxlen - 2:] - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - - token_count = len(remade_tokens) - remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] - cache[tuple_tokens] = (remade_tokens, fixes, multipliers) - - multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) - multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] - - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - - def forward(self, text): - use_old = opts.use_old_emphasis_implementation - if use_old: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) - else: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - - self.hijack.comments += hijack_comments - - if len(used_custom_terms) > 0: - self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - - if use_old: - self.hijack.fixes = hijack_fixes - return self.process_tokens(remade_batch_tokens, batch_multipliers) - - z = None - i = 0 - while max(map(len, remade_batch_tokens)) != 0: - rem_tokens = [x[75:] for x in remade_batch_tokens] - rem_multipliers = [x[75:] for x in batch_multipliers] - - self.hijack.fixes = [] - for unfiltered in hijack_fixes: - fixes = [] - for fix in unfiltered: - if fix[0] == i: - fixes.append(fix[1]) - self.hijack.fixes.append(fixes) - - tokens = [] - multipliers = [] - for j in range(len(remade_batch_tokens)): - if len(remade_batch_tokens[j]) > 0: - tokens.append(remade_batch_tokens[j][:75]) - multipliers.append(batch_multipliers[j][:75]) - else: - tokens.append([self.id_end] * 75) - multipliers.append([1.0] * 75) - - z1 = self.process_tokens(tokens, multipliers) - z = z1 if z is None else torch.cat((z, z1), axis=-2) - - remade_batch_tokens = rem_tokens - batch_multipliers = rem_multipliers - i += 1 + batch_chunks, token_count = self.process_texts(texts) - return z + used_embeddings = {} + chunk_count = max([len(x) for x in batch_chunks]) - def process_tokens(self, remade_batch_tokens, batch_multipliers): - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens] - batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] + zs = [] + for i in range(chunk_count): + batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks] + + tokens = [x.tokens for x in batch_chunk] + multipliers = [x.multipliers for x in batch_chunk] + self.hijack.fixes = [x.fixes for x in batch_chunk] + for fixes in self.hijack.fixes: + for position, embedding in fixes: + used_embeddings[embedding.name] = embedding + + z = self.process_tokens(tokens, multipliers) + zs.append(z) + + if len(used_embeddings) > 0: + embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()]) + self.hijack.comments.append(f"Used embeddings: {embeddings_list}") + + return torch.hstack(zs) + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + """ + sends one single prompt chunk to be encoded by transformers neural network. + remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually + there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens. + Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier + corresponds to one token. + """ tokens = torch.asarray(remade_batch_tokens).to(devices.device) + # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones. if self.id_end != self.id_pad: for batch_pos in range(len(remade_batch_tokens)): index = remade_batch_tokens[batch_pos].index(self.id_end) @@ -239,8 +234,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): z = self.encode_with_transformers(tokens) # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] - batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device) + batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py new file mode 100644 index 00000000..6d9fbbe6 --- /dev/null +++ b/modules/sd_hijack_clip_old.py @@ -0,0 +1,81 @@ +from modules import sd_hijack_clip +from modules import shared + + +def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): + id_start = self.id_start + id_end = self.id_end + maxlen = self.wrapped.max_length # you get to stay at 77 + used_custom_terms = [] + remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_tokens = self.tokenize(texts) + batch_multipliers = [] + for tokens in batch_tokens: + tuple_tokens = tuple(tokens) + + if tuple_tokens in cache: + remade_tokens, fixes, multipliers = cache[tuple_tokens] + else: + fixes = [] + remade_tokens = [] + multipliers = [] + mult = 1.0 + + i = 0 + while i < len(tokens): + token = tokens[i] + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + + mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None + if mult_change is not None: + mult *= mult_change + i += 1 + elif embedding is None: + remade_tokens.append(token) + multipliers.append(mult) + i += 1 + else: + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += embedding_length_in_tokens + + if len(remade_tokens) > maxlen - 2: + vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} + ovf = remade_tokens[maxlen - 2:] + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + + token_count = len(remade_tokens) + remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + cache[tuple_tokens] = (remade_tokens, fixes, multipliers) + + multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) + multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + +def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts) + + self.hijack.comments += hijack_comments + + if len(used_custom_terms) > 0: + self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + + self.hijack.fixes = hijack_fixes + return self.process_tokens(remade_batch_tokens, batch_multipliers) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index f9f5e8cd..45882ed6 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -79,7 +79,6 @@ class EmbeddingDatabase: self.word_embeddings[embedding.name] = embedding - # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working ids = model.cond_stage_model.tokenize([embedding.name])[0] first_id = ids[0] diff --git a/modules/ui.py b/modules/ui.py index b79d24ee..5d2f5bad 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -368,7 +368,7 @@ def update_token_counter(text, steps): flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) prompts = [prompt_text for step, prompt_text in flat_prompts] - tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) + token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) style_class = ' class="red"' if (token_count > max_length) else "" return f"{token_count}/{max_length}" -- cgit v1.2.3 From f94cfc563bbedd923d5e95563a5e8d93c8516ac3 Mon Sep 17 00:00:00 2001 From: Mitchell Boot <47387831+Mitchell1711@users.noreply.github.com> Date: Sat, 7 Jan 2023 01:15:22 +0100 Subject: Changed HTML to textbox instead Using HTML caused an issue where the row would expand for a frame when changing the sliders because of the loading animation. This solution also doesn't use any additional HTML padding --- modules/ui.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 6fc8b7d7..6ea1b5d7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -260,7 +260,7 @@ def calc_resolution_hires(x, y, scale): scaled_x = int(x * scale // 8) * 8 scaled_y = int(y * scale // 8) * 8 - return "

Upscaled resolution: "+str(scaled_x)+"x"+str(scaled_y)+"

" + return str(scaled_x)+"x"+str(scaled_y) def apply_styles(prompt, prompt_neg, style1_name, style2_name): prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) @@ -726,7 +726,10 @@ def create_ui(): hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") with FormRow(elem_id="txt2img_hires_fix_row3"): - hr_final_resolution = gr.HTML(value=calc_resolution_hires(width.value, height.value, hr_scale.value), elem_id="txtimg_hr_finalres") + hr_final_resolution = gr.Textbox(value=calc_resolution_hires(width.value, height.value, hr_scale.value), + elem_id="txtimg_hr_finalres", + label="Upscaled resolution", + interactive=False) hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) -- cgit v1.2.3 From 08066676a47b560235d4c085dd3cfcb470b80997 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 07:22:07 +0300 Subject: make it not break on empty inputs; thank you tarded, we are --- modules/sd_hijack_clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index ac3020d7..16aef76a 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -147,7 +147,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): chunk.multipliers += [weight] * emb_len position += embedding_length_in_tokens - if len(chunk.tokens) > 0: + if len(chunk.tokens) > 0 or len(chunks) == 0: next_chunk() return chunks, token_count -- cgit v1.2.3 From 1740c33547b62f692834c95914a2b295d51684c7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 07:48:44 +0300 Subject: more comments --- modules/sd_hijack_clip.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 16aef76a..5520c9b2 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -3,7 +3,7 @@ from collections import namedtuple import torch -from modules import prompt_parser, devices +from modules import prompt_parser, devices, sd_hijack from modules.shared import opts @@ -22,14 +22,24 @@ class PromptChunk: PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) -"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk""" +"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt +chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally +are applied by sd_hijack.EmbeddingsWithFixes's forward function.""" class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): + """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to + have unlimited prompt length and assign weights to tokens in prompt. + """ + def __init__(self, wrapped, hijack): super().__init__() + self.wrapped = wrapped - self.hijack = hijack + """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation, + depending on model.""" + + self.hijack: sd_hijack.StableDiffusionModelHijack = hijack self.chunk_length = 75 def empty_chunk(self): @@ -55,7 +65,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens; All python lists with tokens are assumed to have same length, usually 77. if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on - model - can be 768 and 1024 + model - can be 768 and 1024. + Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None). """ raise NotImplementedError @@ -113,7 +124,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): last_comma = len(chunk.tokens) # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack - # is a setting that specifies that is there is a comma nearby, the text after comma should be moved out of this chunk and into the next. + # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: break_location = last_comma + 1 -- cgit v1.2.3 From de9738044571877450d1038e18f1ecce93d24af3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 08:53:53 +0300 Subject: this breaks on default config because width, height, hr_scale are None at that point. --- modules/ui.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index f946382d..a18b9007 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -725,14 +725,8 @@ def create_ui(): hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - with FormRow(elem_id="txt2img_hires_fix_row3"): - hr_final_resolution = gr.Textbox(value=calc_resolution_hires(width.value, height.value, hr_scale.value), - elem_id="txtimg_hr_finalres", - label="Upscaled resolution", - interactive=False) - hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) - width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) - height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + with FormRow(elem_id="txt2img_hires_fix_row3"): + hr_final_resolution = gr.Textbox(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) elif category == "batch": if not opts.dimensions_and_batch_together: @@ -744,6 +738,10 @@ def create_ui(): with FormGroup(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) -- cgit v1.2.3 From 1a5b86ad65fd738eadea1ad72f4abad3a4aabf17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 09:56:37 +0300 Subject: rework hires fix preview for #6437: movie it to where it takes less place, make it actually account for all relevant sliders and calculate dimensions correctly --- modules/processing.py | 1 - modules/ui.py | 40 +++++++++++++++++++++++++++------------- modules/ui_components.py | 8 ++++++++ style.css | 17 +++++++++++++++++ 4 files changed, 52 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index a408d622..82157bc9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -711,7 +711,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = 0 self.truncate_y = 0 - def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: if self.hr_resize_x == 0 and self.hr_resize_y == 0: diff --git a/modules/ui.py b/modules/ui.py index a18b9007..6c765262 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru -from modules.ui_components import FormRow, FormGroup, ToolButton +from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path from modules.shared import opts, cmd_opts, restricted_opts @@ -255,12 +255,20 @@ def add_style(name: str, prompt: str, negative_prompt: str): return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] -def calc_resolution_hires(x, y, scale): - #final res can only be a multiple of 8 - scaled_x = int(x * scale // 8) * 8 - scaled_y = int(y * scale // 8) * 8 - - return str(scaled_x)+"x"+str(scaled_y) + +def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): + from modules import processing, devices + + if not enable: + return "" + + p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) + + with devices.autocast(): + p.init([""], [0], [0]) + + return f"resize to: {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" + def apply_styles(prompt, prompt_neg, style1_name, style2_name): prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) @@ -712,6 +720,7 @@ def create_ui(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) elif category == "hires_fix": with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: @@ -724,9 +733,6 @@ def create_ui(): hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - - with FormRow(elem_id="txt2img_hires_fix_row3"): - hr_final_resolution = gr.Textbox(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) elif category == "batch": if not opts.dimensions_and_batch_together: @@ -738,9 +744,16 @@ def create_ui(): with FormGroup(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) - width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) - height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] + hr_resolution_preview_args = dict( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False + ) + + for input in hr_resolution_preview_inputs: + input.change(**hr_resolution_preview_args) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) @@ -803,6 +816,7 @@ def create_ui(): fn=lambda x: gr_show(x), inputs=[enable_hr], outputs=[hr_options], + show_progress = False, ) txt2img_paste_fields = [ diff --git a/modules/ui_components.py b/modules/ui_components.py index 91eb0e3d..cac001dc 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -23,3 +23,11 @@ class FormGroup(gr.Group, gr.components.FormComponent): def get_block_name(self): return "group" + + +class FormHTML(gr.HTML, gr.components.FormComponent): + """Same as gr.HTML but fits inside gradio forms""" + + def get_block_name(self): + return "html" + diff --git a/style.css b/style.css index f1b23b53..76721756 100644 --- a/style.css +++ b/style.css @@ -642,6 +642,23 @@ footer { opacity: 0.85; } +#txtimg_hr_finalres{ + min-height: 0 !important; + padding: .625rem .75rem; + margin-left: -0.75em + +} + +#txtimg_hr_finalres .resolution{ + font-weight: bold; +} + +#txt2img_checkboxes > div > div{ + flex: 0; + white-space: nowrap; + min-width: auto; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From fdfce4711076c2ebac1089bac8169d043eb7978f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 13:29:47 +0300 Subject: add "from" resolution for hires fix to be less confusing. --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 6c765262..99483130 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz with devices.autocast(): p.init([""], [0], [0]) - return f"resize to: {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" + return f"resize: from {width}x{height} to {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" def apply_styles(prompt, prompt_neg, style1_name, style2_name): -- cgit v1.2.3 From df3b31eb559ab9fabf7e513bdeddd5282c16f124 Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 7 Jan 2023 07:04:59 -0500 Subject: In-place operations can break gradient calculation --- modules/sd_hijack_clip.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 5520c9b2..852afc66 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() - z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() - z *= original_mean / new_mean + z = z * (original_mean / new_mean) return z -- cgit v1.2.3 From 47534577eda63b0db1eeb8921c2a161773ec434c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 7 Jan 2023 07:51:35 -0500 Subject: api-get-memory --- modules/api/api.py | 37 +++++++++++++++++++++++++++++++++++++ modules/api/models.py | 4 ++++ 2 files changed, 41 insertions(+) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 2103709b..d2222b18 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -130,6 +130,7 @@ class Api: self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse) self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse) self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse) + self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse) def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: @@ -465,6 +466,42 @@ class Api: shared.state.end() return TrainResponse(info = "train embedding error: {error}".format(error = error)) + def get_memory(self): + def gb(val: float): + return round(val / 1024 / 1024 / 1024, 2) + try: + import os, psutil + process = psutil.Process(os.getpid()) + res = process.memory_info() + ram_total = 100 * res.rss / process.memory_percent() + ram = { 'free': gb(ram_total - res.rss), 'used': gb(res.rss), 'total': gb(ram_total) } + except Exception as err: + ram = { 'error': f'{err}' } + try: + import torch + if torch.cuda.is_available(): + s = torch.cuda.mem_get_info() + system = { 'free': gb(s[0]), 'used': gb(s[1] - s[0]), 'total': gb(s[1]) } + s = dict(torch.cuda.memory_stats(shared.device)) + allocated = { 'current': gb(s['allocated_bytes.all.current']), 'peak': gb(s['allocated_bytes.all.peak']) } + reserved = { 'current': gb(s['reserved_bytes.all.current']), 'peak': gb(s['reserved_bytes.all.peak']) } + active = { 'current': gb(s['active_bytes.all.current']), 'peak': gb(s['active_bytes.all.peak']) } + inactive = { 'current': gb(s['inactive_split_bytes.all.current']), 'peak': gb(s['inactive_split_bytes.all.peak']) } + warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] } + cuda = { + 'system': system, + 'active': active, + 'allocated': allocated, + 'reserved': reserved, + 'inactive': inactive, + 'events': warnings, + } + else: + cuda = { 'error': 'unavailable' } + except Exception as err: + cuda = { 'error': f'{err}' } + return MemoryResponse(ram = ram, cuda = cuda) + def launch(self, server_name, port): self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port) diff --git a/modules/api/models.py b/modules/api/models.py index 5fa63774..49bf1e7a 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -260,3 +260,7 @@ class EmbeddingItem(BaseModel): class EmbeddingsResponse(BaseModel): loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model") skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") + +class MemoryResponse(BaseModel): + ram: dict[str, str] | dict[str, float] = Field(title="RAM", description="System memory stats") + cuda: dict[str, str] | dict[str, dict] = Field(title="CUDA", description="nVidia CUDA memory stats") -- cgit v1.2.3 From d38ede71d5330958f4bbac5f99c1be3c146b506a Mon Sep 17 00:00:00 2001 From: noodleanon <122053346+noodleanon@users.noreply.github.com> Date: Sat, 7 Jan 2023 14:21:31 +0000 Subject: Added script support in txt2img endpoint --- modules/api/api.py | 22 +++++++++++++++++++--- modules/api/models.py | 2 +- 2 files changed, 20 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index aa62a42e..0e8ea263 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -149,6 +149,14 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): + if txt2imgreq.script_name is not None: + if scripts.scripts_txt2img.scripts == []: + scripts.scripts_txt2img.initialize_scripts(True) + ui.create_ui() + + script_idx = script_name_to_index(txt2imgreq.script_name, scripts.scripts_txt2img.selectable_scripts) + script = scripts.scripts_txt2img.selectable_scripts[script_idx] + populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), "do_not_save_samples": True, @@ -158,11 +166,20 @@ class Api: if populate.sampler_name: populate.sampler_index = None # prevent a warning later on + args = vars(populate) + args.pop('script_name', None) + with self.queue_lock: - p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate)) + p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) shared.state.begin() - processed = process_images(p) + if 'script' in locals(): + p.outpath_grids = opts.outdir_txt2img_grids + p.outpath_samples = opts.outdir_txt2img_samples + p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args + processed = scripts.scripts_txt2img.run(p, *p.script_args) + else: + processed = process_images(p) shared.state.end() @@ -213,7 +230,6 @@ class Api: processed = scripts.scripts_img2img.run(p, *p.script_args) else: processed = process_images(p) - shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) diff --git a/modules/api/models.py b/modules/api/models.py index c85eb94d..ce43c858 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -100,7 +100,7 @@ class PydanticModelGenerator: StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] ).generate_model() StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( -- cgit v1.2.3 From 448b9cedab66e05b5b2800513ca334a769b42aa7 Mon Sep 17 00:00:00 2001 From: dan Date: Sat, 7 Jan 2023 21:07:27 +0800 Subject: Allow variable img size --- modules/textual_inversion/dataset.py | 18 +++++++++++------- modules/textual_inversion/textual_inversion.py | 4 ++-- 2 files changed, 13 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 88d68c76..375178ed 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None): self.filename = filename self.filename_text = filename_text self.latent_dist = latent_dist @@ -25,6 +25,7 @@ class DatasetEntry: self.cond = cond self.cond_text = cond_text self.pixel_values = pixel_values + self.img_shape = img_shape class PersonalizedBase(Dataset): @@ -33,8 +34,6 @@ class PersonalizedBase(Dataset): self.placeholder_token = placeholder_token - self.width = width - self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.dataset = [] @@ -59,7 +58,11 @@ class PersonalizedBase(Dataset): if shared.state.interrupted: raise Exception("interrupted") try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) + image = Image.open(path).convert('RGB') + if width < 2000: + image = image.resize((width, height), PIL.Image.BICUBIC) + else: + assert batch_size == 1, 'variable img size must have batch size 1' except Exception: continue @@ -88,14 +91,14 @@ class PersonalizedBase(Dataset): if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) latent_sampling_method = "once" - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) elif latent_sampling_method == "deterministic": # Works only for DiagonalGaussianDistribution latent_dist.std = 0 latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) elif latent_sampling_method == "random": - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size) if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) @@ -151,6 +154,7 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) + self.img_shape = [entry.img_shape for entry in data] #self.emb_index = [entry.emb_index for entry in data] #print(self.latent_sample.device) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 45882ed6..9f96d0fd 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -451,8 +451,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ else: p.prompt = batch.cond_text[0] p.steps = 20 - p.width = training_width - p.height = training_height + p.width = batch.img_shape[0][0] + p.height = batch.img_shape[0][1] preview_text = p.prompt -- cgit v1.2.3 From 669fb18d5222f53ae48abe0f30393d846c50ad91 Mon Sep 17 00:00:00 2001 From: dan Date: Sun, 8 Jan 2023 01:34:52 +0800 Subject: Add checkbox for variable training dims --- modules/hypernetworks/hypernetwork.py | 2 +- modules/textual_inversion/dataset.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 4 ++-- modules/ui.py | 3 +++ 4 files changed, 8 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b0cfbe71..dba52841 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -403,7 +403,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 375178ed..7f8a314f 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -29,7 +29,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None self.placeholder_token = placeholder_token @@ -59,7 +59,7 @@ class PersonalizedBase(Dataset): raise Exception("interrupted") try: image = Image.open(path).convert('RGB') - if width < 2000: + if not varsize: image = image.resize((width, height), PIL.Image.BICUBIC) else: assert batch_size == 1, 'variable img size must have batch size 1' diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 9f96d0fd..110efd19 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -255,7 +255,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") @@ -309,7 +309,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ pin_memory = shared.opts.pin_memory - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) diff --git a/modules/ui.py b/modules/ui.py index 99483130..4e709a71 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1343,6 +1343,7 @@ def create_ui(): template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + varsize = gr.Checkbox(label="Ignore dimension settings and do not resize images", value=False, elem_id="train_varsize") steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") with FormRow(): @@ -1449,6 +1450,7 @@ def create_ui(): log_directory, training_width, training_height, + varsize, steps, clip_grad_mode, clip_grad_value, @@ -1480,6 +1482,7 @@ def create_ui(): log_directory, training_width, training_height, + varsize, steps, clip_grad_mode, clip_grad_value, -- cgit v1.2.3 From 72497895b9b1948f86d9309fe897cbb70c20ba7e Mon Sep 17 00:00:00 2001 From: dan Date: Sun, 8 Jan 2023 01:36:00 +0800 Subject: Move batchsize check --- modules/hypernetworks/hypernetwork.py | 2 +- modules/textual_inversion/dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index dba52841..32c67ccc 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -456,7 +456,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, pin_memory = shared.opts.pin_memory - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: saved_params = dict( diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 7f8a314f..bcad6848 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -46,6 +46,8 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" + if varsize: + assert batch_size == 1, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] @@ -61,8 +63,6 @@ class PersonalizedBase(Dataset): image = Image.open(path).convert('RGB') if not varsize: image = image.resize((width, height), PIL.Image.BICUBIC) - else: - assert batch_size == 1, 'variable img size must have batch size 1' except Exception: continue -- cgit v1.2.3 From 984b86dd0abf0da7f6b116864c791a2bfe8859ef Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sat, 7 Jan 2023 13:08:21 -0700 Subject: Add fallback for Protocol import --- modules/sub_quadratic_attention.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index fea7aaac..93381bae 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -15,7 +15,13 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import math -from typing import Optional, NamedTuple, Protocol, List + +try: + from typing import Protocol +except: + from typing_extensions import Protocol + +from typing import Optional, NamedTuple, List def narrow_trunc( input: Tensor, -- cgit v1.2.3 From a0c87f1fdf2b76b2ae4ef6c4b01ddaede3afab06 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 8 Jan 2023 08:52:26 +0300 Subject: skip images in embeddings dir if they have a second .preview extension --- modules/textual_inversion/textual_inversion.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 45882ed6..e85dd549 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -109,6 +109,10 @@ class EmbeddingDatabase: ext = ext.upper() if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': + return + embed_image = Image.open(path) if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: data = embedding_from_b64(embed_image.text['sd-ti-embedding']) -- cgit v1.2.3 From 085427de0efc9e9e7a6e9a5aebc6b5a69f0365e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 8 Jan 2023 09:37:33 +0300 Subject: make it possible for extensions/scripts to add their own embedding directories --- modules/sd_hijack.py | 7 +- modules/textual_inversion/textual_inversion.py | 170 +++++++++++++++---------- 2 files changed, 108 insertions(+), 69 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index cfdb09d6..6b0d95af 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -83,10 +83,12 @@ class StableDiffusionModelHijack: clip = None optimization_method = None - embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) + embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() - def hijack(self, m): + def __init__(self): + self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) + def hijack(self, m): if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) @@ -117,7 +119,6 @@ class StableDiffusionModelHijack: self.layers = flatten(m) def undo_hijack(self, m): - if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: m.cond_stage_model = m.cond_stage_model.wrapped diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e85dd549..217fe9eb 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -66,17 +66,41 @@ class Embedding: return self.cached_checksum +class DirWithTextualInversionEmbeddings: + def __init__(self, path): + self.path = path + self.mtime = None + + def has_changed(self): + if not os.path.isdir(self.path): + return False + + mt = os.path.getmtime(self.path) + if self.mtime is None or mt > self.mtime: + return True + + def update(self): + if not os.path.isdir(self.path): + return + + self.mtime = os.path.getmtime(self.path) + + class EmbeddingDatabase: - def __init__(self, embeddings_dir): + def __init__(self): self.ids_lookup = {} self.word_embeddings = {} self.skipped_embeddings = {} - self.dir_mtime = None - self.embeddings_dir = embeddings_dir self.expected_shape = -1 + self.embedding_dirs = {} - def register_embedding(self, embedding, model): + def add_embedding_dir(self, path): + self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) + + def clear_embedding_dirs(self): + self.embedding_dirs.clear() + def register_embedding(self, embedding, model): self.word_embeddings[embedding.name] = embedding ids = model.cond_stage_model.tokenize([embedding.name])[0] @@ -93,69 +117,62 @@ class EmbeddingDatabase: vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) return vec.shape[1] - def load_textual_inversion_embeddings(self, force_reload = False): - mt = os.path.getmtime(self.embeddings_dir) - if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime: - return + def load_from_file(self, path, filename): + name, ext = os.path.splitext(filename) + ext = ext.upper() - self.dir_mtime = mt - self.ids_lookup.clear() - self.word_embeddings.clear() - self.skipped_embeddings.clear() - self.expected_shape = self.get_expected_shape() - - def process_file(path, filename): - name, ext = os.path.splitext(filename) - ext = ext.upper() - - if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: - _, second_ext = os.path.splitext(name) - if second_ext.upper() == '.PREVIEW': - return - - embed_image = Image.open(path) - if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: - data = embedding_from_b64(embed_image.text['sd-ti-embedding']) - name = data.get('name', name) - else: - data = extract_image_data_embed(embed_image) - name = data.get('name', name) - elif ext in ['.BIN', '.PT']: - data = torch.load(path, map_location="cpu") - else: + if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': return - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - else: - raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") - - vec = emb.detach().to(devices.device, dtype=torch.float32) - embedding = Embedding(vec, name) - embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('sd_checkpoint', None) - embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) - embedding.vectors = vec.shape[0] - embedding.shape = vec.shape[-1] - - if self.expected_shape == -1 or self.expected_shape == embedding.shape: - self.register_embedding(embedding, shared.sd_model) + embed_image = Image.open(path) + if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: + data = embedding_from_b64(embed_image.text['sd-ti-embedding']) + name = data.get('name', name) else: - self.skipped_embeddings[name] = embedding + data = extract_image_data_embed(embed_image) + name = data.get('name', name) + elif ext in ['.BIN', '.PT']: + data = torch.load(path, map_location="cpu") + else: + return + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + vec = emb.detach().to(devices.device, dtype=torch.float32) + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + embedding.vectors = vec.shape[0] + embedding.shape = vec.shape[-1] + + if self.expected_shape == -1 or self.expected_shape == embedding.shape: + self.register_embedding(embedding, shared.sd_model) + else: + self.skipped_embeddings[name] = embedding - for root, dirs, fns in os.walk(self.embeddings_dir): + def load_from_dir(self, embdir): + if not os.path.isdir(embdir.path): + return + + for root, dirs, fns in os.walk(embdir.path): for fn in fns: try: fullfn = os.path.join(root, fn) @@ -163,12 +180,32 @@ class EmbeddingDatabase: if os.stat(fullfn).st_size == 0: continue - process_file(fullfn, fn) + self.load_from_file(fullfn, fn) except Exception: print(f"Error loading embedding {fn}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) continue + def load_textual_inversion_embeddings(self, force_reload=False): + if not force_reload: + need_reload = False + for path, embdir in self.embedding_dirs.items(): + if embdir.has_changed(): + need_reload = True + break + + if not need_reload: + return + + self.ids_lookup.clear() + self.word_embeddings.clear() + self.skipped_embeddings.clear() + self.expected_shape = self.get_expected_shape() + + for path, embdir in self.embedding_dirs.items(): + self.load_from_dir(embdir) + embdir.update() + print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") if len(self.skipped_embeddings) > 0: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") @@ -251,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert os.path.isfile(template_file), "Prompt template file doesn't exist" assert steps, "Max steps is empty or 0" assert isinstance(steps, int), "Max steps must be integer" - assert steps > 0 , "Max steps must be positive" + assert steps > 0, "Max steps must be positive" assert isinstance(save_model_every, int), "Save {name} must be integer" - assert save_model_every >= 0 , "Save {name} must be positive or 0" + assert save_model_every >= 0, "Save {name} must be positive or 0" assert isinstance(create_image_every, int), "Create image must be integer" - assert create_image_every >= 0 , "Create image must be positive or 0" + assert create_image_every >= 0, "Create image must be positive or 0" if save_model_every or create_image_every: assert log_directory, "Log directory is empty" + def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 -- cgit v1.2.3 From 6d0cc1e239e0a43a2e6d696eae20c66fad0819bb Mon Sep 17 00:00:00 2001 From: noodleanon <122053346+noodleanon@users.noreply.github.com> Date: Sun, 8 Jan 2023 11:03:48 +0000 Subject: Corrected is_img2img param --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 0e8ea263..1785a6b4 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -151,7 +151,7 @@ class Api: def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): if txt2imgreq.script_name is not None: if scripts.scripts_txt2img.scripts == []: - scripts.scripts_txt2img.initialize_scripts(True) + scripts.scripts_txt2img.initialize_scripts(False) ui.create_ui() script_idx = script_name_to_index(txt2imgreq.script_name, scripts.scripts_txt2img.selectable_scripts) -- cgit v1.2.3 From 137ce534b2355a527cd1a50c192909161258b442 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 8 Jan 2023 16:14:38 +0300 Subject: remove some code duplication remove calls to locals() add a test for img2img with script --- modules/api/api.py | 33 ++++++++++++++++----------------- test/basic_features/img2img_test.py | 6 ++++++ 2 files changed, 22 insertions(+), 17 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 1785a6b4..5b6125f8 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -148,14 +148,20 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) - def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - if txt2imgreq.script_name is not None: - if scripts.scripts_txt2img.scripts == []: - scripts.scripts_txt2img.initialize_scripts(False) - ui.create_ui() + def get_script(self, script_name, script_runner): + if script_name is None: + return None, None + + if not script_runner.scripts: + script_runner.initialize_scripts(False) + ui.create_ui() + + script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) + script = script_runner.selectable_scripts[script_idx] + return script, script_idx - script_idx = script_name_to_index(txt2imgreq.script_name, scripts.scripts_txt2img.selectable_scripts) - script = scripts.scripts_txt2img.selectable_scripts[script_idx] + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): + script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), @@ -173,7 +179,7 @@ class Api: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) shared.state.begin() - if 'script' in locals(): + if script is not None: p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args @@ -182,7 +188,6 @@ class Api: processed = process_images(p) shared.state.end() - b64images = list(map(encode_pil_to_base64, processed.images)) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) @@ -192,13 +197,7 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") - if img2imgreq.script_name is not None: - if scripts.scripts_img2img.scripts == []: - scripts.scripts_img2img.initialize_scripts(True) - ui.create_ui() - - script_idx = script_name_to_index(img2imgreq.script_name, scripts.scripts_img2img.selectable_scripts) - script = scripts.scripts_img2img.selectable_scripts[script_idx] + script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img) mask = img2imgreq.mask if mask: @@ -223,7 +222,7 @@ class Api: p.init_images = [decode_base64_to_image(x) for x in init_images] shared.state.begin() - if 'script' in locals(): + if script is not None: p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py index 0a9c1e8a..bd520b13 100644 --- a/test/basic_features/img2img_test.py +++ b/test/basic_features/img2img_test.py @@ -50,6 +50,12 @@ class TestImg2ImgWorking(unittest.TestCase): self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) + def test_img2img_sd_upscale_performed(self): + self.simple_img2img["script_name"] = "sd upscale" + self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0] + + self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) + if __name__ == "__main__": unittest.main() -- cgit v1.2.3 From cb255faec6e5f6b47b7632e6b7d450b9e2f6678b Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Sun, 8 Jan 2023 10:17:50 -0700 Subject: Add support for loading VAEs from safetensor files --- modules/sd_vae.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index ac71d62d..9fcfd9db 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,4 +1,5 @@ import torch +import safetensors.torch import os import collections from collections import namedtuple @@ -72,8 +73,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): candidates = [ *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), + *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True), *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True) + *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True), ] if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): candidates.append(shared.cmd_opts.vae_path) @@ -137,6 +140,12 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"): if os.path.isfile(vae_file_try): vae_file = vae_file_try print(f"Using VAE found similar to selected model: {vae_file}") + # if still not found, try look for ".vae.safetensors" beside model + if vae_file == "auto": + vae_file_try = model_path + ".vae.safetensors" + if os.path.isfile(vae_file_try): + vae_file = vae_file_try + print(f"Using VAE found similar to selected model: {vae_file}") # No more fallbacks for auto if vae_file == "auto": vae_file = None @@ -163,8 +172,14 @@ def load_vae(model, vae_file=None): assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" print(f"Loading VAE weights from: {vae_file}") store_base_vae(model) - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + _, extension = os.path.splitext(vae_file) + if extension.lower() == ".safetensors": + vae_ckpt = safetensors.torch.load_file(vae_file, device=shared.weight_load_location) + else: + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + if "state_dict" in vae_ckpt: + vae_ckpt = vae_ckpt["state_dict"] + vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} _load_vae_dict(model, vae_dict_1) if cache_enabled: -- cgit v1.2.3 From d4fd2418efb0986a8226add0b800fb5c73ffb58c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 14:57:47 +0300 Subject: add an option to use old hiresfix width/height behavior add a visual effect to inactive hires fix elements --- javascript/hires_fix.js | 25 +++++++++++++++++++++++++ modules/generation_parameters_copypaste.py | 17 +++++++++++------ modules/processing.py | 26 ++++++++++++++++++++++++-- modules/shared.py | 1 + modules/ui.py | 23 ++++++++++++++--------- style.css | 4 ++++ 6 files changed, 79 insertions(+), 17 deletions(-) create mode 100644 javascript/hires_fix.js (limited to 'modules') diff --git a/javascript/hires_fix.js b/javascript/hires_fix.js new file mode 100644 index 00000000..07fba549 --- /dev/null +++ b/javascript/hires_fix.js @@ -0,0 +1,25 @@ + +function setInactive(elem, inactive){ + console.log(elem) + if(inactive){ + elem.classList.add('inactive') + } else{ + elem.classList.remove('inactive') + } +} + +function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){ + console.log(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) + + hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale') + hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x') + hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y') + + gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "" + + setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0) + setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0) + setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0) + + return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y] +} diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 12a9de3d..f7f68b67 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -197,6 +197,15 @@ def restore_old_hires_fix_params(res): firstpass_width = res.get('First pass size-1', None) firstpass_height = res.get('First pass size-2', None) + if shared.opts.use_old_hires_fix_width_height: + hires_width = int(res.get("Hires resize-1", None)) + hires_height = int(res.get("Hires resize-2", None)) + + if hires_width is not None and hires_height is not None: + res['Size-1'] = hires_width + res['Size-2'] = hires_height + return + if firstpass_width is None or firstpass_height is None: return @@ -205,12 +214,8 @@ def restore_old_hires_fix_params(res): height = int(res.get("Size-2", 512)) if firstpass_width == 0 or firstpass_height == 0: - # old algorithm for auto-calculating first pass size - desired_pixel_count = 512 * 512 - actual_pixel_count = width * height - scale = math.sqrt(desired_pixel_count / actual_pixel_count) - firstpass_width = math.ceil(scale * width / 64) * 64 - firstpass_height = math.ceil(scale * height / 64) * 64 + from modules import processing + firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height) res['Size-1'] = firstpass_width res['Size-2'] = firstpass_height diff --git a/modules/processing.py b/modules/processing.py index 1d23b15f..f04a0e1e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -687,6 +687,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: return res +def old_hires_fix_first_pass_dimensions(width, height): + """old algorithm for auto-calculating first pass size""" + + desired_pixel_count = 512 * 512 + actual_pixel_count = width * height + scale = math.sqrt(desired_pixel_count / actual_pixel_count) + width = math.ceil(scale * width / 64) * 64 + height = math.ceil(scale * height / 64) * 64 + + return width, height + + class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None @@ -703,16 +715,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_upscale_to_y = hr_resize_y if firstphase_width != 0 or firstphase_height != 0: - print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) - self.hr_scale = self.width / firstphase_width + self.hr_upscale_to_x = self.width + self.hr_upscale_to_y = self.height self.width = firstphase_width self.height = firstphase_height self.truncate_x = 0 self.truncate_y = 0 + self.applied_old_hires_behavior_to = None def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: + if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height): + self.hr_resize_x = self.width + self.hr_resize_y = self.height + self.hr_upscale_to_x = self.width + self.hr_upscale_to_y = self.height + + self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height) + self.applied_old_hires_behavior_to = (self.width, self.height) + if self.hr_resize_x == 0 and self.hr_resize_y == 0: self.extra_generation_params["Hires upscale"] = self.hr_scale self.hr_upscale_to_x = int(self.width * self.hr_scale) diff --git a/modules/shared.py b/modules/shared.py index a6712dae..a1e10201 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -398,6 +398,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), + "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { diff --git a/modules/ui.py b/modules/ui.py index 99483130..719c26b3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz with devices.autocast(): p.init([""], [0], [0]) - return f"resize: from {width}x{height} to {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" + return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" def apply_styles(prompt, prompt_neg, style1_name, style2_name): @@ -745,15 +745,20 @@ def create_ui(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - hr_resolution_preview_args = dict( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False - ) - for input in hr_resolution_preview_inputs: - input.change(**hr_resolution_preview_args) + input.change( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + input.change( + None, + _js="onCalcResolutionHires", + inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) diff --git a/style.css b/style.css index d796cbe9..ec5e4182 100644 --- a/style.css +++ b/style.css @@ -670,6 +670,10 @@ footer { min-width: auto; } +.inactive{ + opacity: 0.5; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From 49c4509ce2302350210ff650fd26373518c46a79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 19:58:35 +0300 Subject: use existing function for loading VAE weights from file --- modules/sd_vae.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 9fcfd9db..0a49daa1 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -3,7 +3,7 @@ import safetensors.torch import os import collections from collections import namedtuple -from modules import shared, devices, script_callbacks +from modules import shared, devices, script_callbacks, sd_models from modules.paths import models_path import glob from copy import deepcopy @@ -172,13 +172,8 @@ def load_vae(model, vae_file=None): assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" print(f"Loading VAE weights from: {vae_file}") store_base_vae(model) - _, extension = os.path.splitext(vae_file) - if extension.lower() == ".safetensors": - vae_ckpt = safetensors.torch.load_file(vae_file, device=shared.weight_load_location) - else: - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - if "state_dict" in vae_ckpt: - vae_ckpt = vae_ckpt["state_dict"] + + vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} _load_vae_dict(model, vae_dict_1) @@ -210,10 +205,12 @@ def _load_vae_dict(model, vae_dict_1): model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) + def clear_loaded_vae(): global loaded_vae_file loaded_vae_file = None + def reload_vae_weights(sd_model=None, vae_file="auto"): from modules import lowvram, devices, sd_hijack -- cgit v1.2.3 From cdfcbd995932ffa728db0cc00a5f97665c752103 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 20:08:48 +0300 Subject: Remove fallback for Protocol import and remove Protocol import and remove instances of Protocol in code add some whitespace between functions to be in line with other code in the repo --- modules/sub_quadratic_attention.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 93381bae..55052815 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -15,14 +15,9 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import math - -try: - from typing import Protocol -except: - from typing_extensions import Protocol - from typing import Optional, NamedTuple, List + def narrow_trunc( input: Tensor, dim: int, @@ -31,12 +26,14 @@ def narrow_trunc( ) -> Tensor: return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) + class AttnChunk(NamedTuple): exp_values: Tensor exp_weights_sum: Tensor max_score: Tensor -class SummarizeChunk(Protocol): + +class SummarizeChunk: @staticmethod def __call__( query: Tensor, @@ -44,7 +41,8 @@ class SummarizeChunk(Protocol): value: Tensor, ) -> AttnChunk: ... -class ComputeQueryChunkAttn(Protocol): + +class ComputeQueryChunkAttn: @staticmethod def __call__( query: Tensor, @@ -52,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol): value: Tensor, ) -> Tensor: ... + def _summarize_chunk( query: Tensor, key: Tensor, @@ -72,6 +71,7 @@ def _summarize_chunk( max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + def _query_chunk_attention( query: Tensor, key: Tensor, @@ -112,6 +112,7 @@ def _query_chunk_attention( all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) return all_values / all_weights + # TODO: refactor CrossAttention#get_attention_scores to share code with this def _get_attention_scores_no_kv_chunking( query: Tensor, @@ -131,10 +132,12 @@ def _get_attention_scores_no_kv_chunking( hidden_states_slice = torch.bmm(attn_probs, value) return hidden_states_slice + class ScannedChunk(NamedTuple): chunk_idx: int attn_chunk: AttnChunk + def efficient_dot_product_attention( query: Tensor, key: Tensor, -- cgit v1.2.3 From 43bb5190fc9e7ae479a5dc6640be202c9a71e464 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 22:52:23 +0300 Subject: remove/simplify some changes from #6481 --- modules/textual_inversion/dataset.py | 14 +++++--------- modules/textual_inversion/textual_inversion.py | 4 ++-- modules/ui.py | 2 +- 3 files changed, 8 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index bcad6848..fa48708e 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None): + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): self.filename = filename self.filename_text = filename_text self.latent_dist = latent_dist @@ -25,7 +25,6 @@ class DatasetEntry: self.cond = cond self.cond_text = cond_text self.pixel_values = pixel_values - self.img_shape = img_shape class PersonalizedBase(Dataset): @@ -46,12 +45,10 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - if varsize: - assert batch_size == 1, 'variable img size must have batch size 1' + assert batch_size == 1 or not varsize, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] - self.shuffle_tags = shuffle_tags self.tag_drop_out = tag_drop_out @@ -91,14 +88,14 @@ class PersonalizedBase(Dataset): if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) latent_sampling_method = "once" - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) elif latent_sampling_method == "deterministic": # Works only for DiagonalGaussianDistribution latent_dist.std = 0 latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) elif latent_sampling_method == "random": - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) @@ -154,7 +151,6 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) - self.img_shape = [entry.img_shape for entry in data] #self.emb_index = [entry.emb_index for entry in data] #print(self.latent_sample.device) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ad76297e..14be2c96 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -492,8 +492,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ else: p.prompt = batch.cond_text[0] p.steps = 20 - p.width = batch.img_shape[0][0] - p.height = batch.img_shape[0][1] + p.width = training_width + p.height = training_height preview_text = p.prompt diff --git a/modules/ui.py b/modules/ui.py index 9d6b141e..ddfe1b1a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1348,7 +1348,7 @@ def create_ui(): template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") - varsize = gr.Checkbox(label="Ignore dimension settings and do not resize images", value=False, elem_id="train_varsize") + varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") with FormRow(): -- cgit v1.2.3 From 1fbb6f9ebe48326a3b12ecf611105dbc4a46891e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 23:35:40 +0300 Subject: make a dropdown for prompt template selection --- modules/hypernetworks/hypernetwork.py | 7 ++++-- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 35 ++++++++++++++++++++------ modules/ui.py | 11 ++++++-- webui.py | 3 +++ 5 files changed, 45 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 32c67ccc..ea3f1db9 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -24,6 +24,7 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} + class HypernetworkModule(torch.nn.Module): multiplier = 1.0 activation_dict = { @@ -403,13 +404,15 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 - textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + template_file = textual_inversion.textual_inversion_templates.get(template_filename, None) + textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + template_file = template_file.path path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() diff --git a/modules/shared.py b/modules/shared.py index a1e10201..aa37c8ce 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -33,6 +33,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 14be2c96..5420903f 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -2,6 +2,7 @@ import os import sys import traceback import inspect +from collections import namedtuple import torch import tqdm @@ -15,12 +16,26 @@ from modules import shared, devices, sd_hijack, processing, sd_models, images, s import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler -from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, - insert_image_data_embed, extract_image_data_embed, - caption_image_overlay) +from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay from modules.textual_inversion.logging import save_settings_to_file +TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"]) +textual_inversion_templates = {} + + +def list_textual_inversion_templates(): + textual_inversion_templates.clear() + + for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir): + for fn in fns: + path = os.path.join(root, fn) + + textual_inversion_templates[fn] = TextualInversionTemplate(fn, path) + + return textual_inversion_templates + + class Embedding: def __init__(self, vec, name, step=None): self.vec = vec @@ -274,7 +289,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) -def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): +def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" assert isinstance(batch_size, int), "Batch size must be integer" @@ -284,8 +299,9 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert data_root, "Dataset directory is empty" assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - assert template_file, "Prompt template file is empty" - assert os.path.isfile(template_file), "Prompt template file doesn't exist" + assert template_filename, "Prompt template file not selected" + assert template_file, f"Prompt template file {template_filename} not found" + assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist" assert steps, "Max steps is empty or 0" assert isinstance(steps, int), "Max steps must be integer" assert steps > 0, "Max steps must be positive" @@ -296,10 +312,13 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 - validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + template_file = textual_inversion_templates.get(template_filename, None) + validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + template_file = template_file.path shared.state.job = "train-embedding" shared.state.textinfo = "Initializing textual inversion training..." diff --git a/modules/ui.py b/modules/ui.py index ddfe1b1a..b6079aec 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -37,7 +37,7 @@ from modules import prompt_parser from modules.images import save_image from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img -import modules.textual_inversion.ui +from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text @@ -1322,6 +1322,9 @@ def create_ui(): outputs=[process_focal_crop_row], ) + def get_textual_inversion_template_names(): + return sorted([x for x in textual_inversion.textual_inversion_templates]) + with gr.Tab(label="Train"): gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") with FormRow(): @@ -1345,7 +1348,11 @@ def create_ui(): dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") + + with FormRow(): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") diff --git a/webui.py b/webui.py index 8737e593..47d372c7 100644 --- a/webui.py +++ b/webui.py @@ -33,6 +33,7 @@ import modules.sd_models import modules.sd_vae import modules.txt2img import modules.script_callbacks +import modules.textual_inversion.textual_inversion import modules.ui from modules import modelloader @@ -67,6 +68,8 @@ def initialize(): modules.sd_vae.refresh_vae_list() + modules.textual_inversion.textual_inversion.list_textual_inversion_templates() + try: modules.sd_models.load_model() except Exception as e: -- cgit v1.2.3 From 95727312ca5913876aa1c74f47d1ff6d93bb6b1f Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 9 Jan 2023 16:54:12 -0500 Subject: remove bytes -> gb conversion --- modules/api/api.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index d2222b18..1c121ff0 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -467,26 +467,24 @@ class Api: return TrainResponse(info = "train embedding error: {error}".format(error = error)) def get_memory(self): - def gb(val: float): - return round(val / 1024 / 1024 / 1024, 2) try: import os, psutil process = psutil.Process(os.getpid()) - res = process.memory_info() - ram_total = 100 * res.rss / process.memory_percent() - ram = { 'free': gb(ram_total - res.rss), 'used': gb(res.rss), 'total': gb(ram_total) } + res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values + ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe + ram = { 'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total } except Exception as err: ram = { 'error': f'{err}' } try: import torch if torch.cuda.is_available(): s = torch.cuda.mem_get_info() - system = { 'free': gb(s[0]), 'used': gb(s[1] - s[0]), 'total': gb(s[1]) } + system = { 'free': s[0], 'used': s[1] - s[0], 'total': s[1] } s = dict(torch.cuda.memory_stats(shared.device)) - allocated = { 'current': gb(s['allocated_bytes.all.current']), 'peak': gb(s['allocated_bytes.all.peak']) } - reserved = { 'current': gb(s['reserved_bytes.all.current']), 'peak': gb(s['reserved_bytes.all.peak']) } - active = { 'current': gb(s['active_bytes.all.current']), 'peak': gb(s['active_bytes.all.peak']) } - inactive = { 'current': gb(s['inactive_split_bytes.all.current']), 'peak': gb(s['inactive_split_bytes.all.peak']) } + allocated = { 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak'] } + reserved = { 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak'] } + active = { 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak'] } + inactive = { 'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak'] } warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] } cuda = { 'system': system, -- cgit v1.2.3 From 3fe9e9e54dcfc41d7c5ee6976f83b0de29fd3dda Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 02:17:33 +0300 Subject: fix broken resolution detection when pasting parameters with old hires fix enabled --- modules/generation_parameters_copypaste.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index f7f68b67..620aa606 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -198,10 +198,10 @@ def restore_old_hires_fix_params(res): firstpass_height = res.get('First pass size-2', None) if shared.opts.use_old_hires_fix_width_height: - hires_width = int(res.get("Hires resize-1", None)) - hires_height = int(res.get("Hires resize-2", None)) + hires_width = int(res.get("Hires resize-1", 0)) + hires_height = int(res.get("Hires resize-2", 0)) - if hires_width is not None and hires_height is not None: + if hires_width and hires_height: res['Size-1'] = hires_width res['Size-2'] = hires_height return -- cgit v1.2.3 From 552d7b90bf483c160cd20740f7acd7fccbc02e6f Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 9 Jan 2023 18:34:26 -0500 Subject: allow model load if previous model failed --- modules/sd_models.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 76a89e88..0a6d55ca 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -49,6 +49,9 @@ def checkpoint_tiles(): def find_checkpoint_config(info): + if info is None: + return shared.cmd_opts.config + config = os.path.splitext(info.filename)[0] + ".yaml" if os.path.exists(config): return config @@ -345,14 +348,16 @@ def reload_model_weights(sd_model=None, info=None): if not sd_model: sd_model = shared.sd_model + if sd_model is None: # previous model load failed + current_checkpoint_info = None + else: + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename: + return - current_checkpoint_info = sd_model.sd_checkpoint_info checkpoint_config = find_checkpoint_config(current_checkpoint_info) - if sd_model.sd_model_checkpoint == checkpoint_info.filename: - return - - if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): + if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): del sd_model checkpoints_loaded.clear() load_model(checkpoint_info) -- cgit v1.2.3 From 2275f130bfe437c3245a66559f92af94d0e4d8ff Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 9 Jan 2023 21:23:58 -0500 Subject: relax reponse type check enforcement --- modules/api/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/api/models.py b/modules/api/models.py index 880edde6..034b4aa0 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -262,5 +262,5 @@ class EmbeddingsResponse(BaseModel): skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") class MemoryResponse(BaseModel): - ram: dict[str, str] | dict[str, float] = Field(title="RAM", description="System memory stats") - cuda: dict[str, str] | dict[str, dict] = Field(title="CUDA", description="nVidia CUDA memory stats") + ram: dict = Field(title="RAM", description="System memory stats") + cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats") -- cgit v1.2.3 From a4a5475cfa3c68af6cb046081002a72f862ce4be Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Tue, 10 Jan 2023 14:56:57 +0900 Subject: Variable dropout rate Implements variable dropout rate from #4549 Fixes hypernetwork multiplier being able to modified during training, also fixes user-errors by setting multiplier value to lower values for training. Changes function name to match torch.nn.module standard Fixes RNG reset issue when generating previews by restoring RNG state --- modules/hypernetworks/hypernetwork.py | 101 +++++++++++++++++++++++++--------- modules/hypernetworks/ui.py | 4 +- modules/ui.py | 4 +- 3 files changed, 81 insertions(+), 28 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index ea3f1db9..300d3975 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -39,7 +39,7 @@ class HypernetworkModule(torch.nn.Module): activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', - add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False): + add_layer_norm=False, activate_output=False, dropout_structure=None): super().__init__() assert layer_structure is not None, "layer_structure must not be None" @@ -64,9 +64,12 @@ class HypernetworkModule(torch.nn.Module): if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) - # Add dropout except last layer - if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2): - linears.append(torch.nn.Dropout(p=0.3)) + # Everything should be now parsed into dropout structure, and applied here. + # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0. + if dropout_structure is not None and dropout_structure[i+1] > 0: + assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!" + linears.append(torch.nn.Dropout(p=dropout_structure[i+1])) + # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0]. self.linear = torch.nn.Sequential(*linears) @@ -113,7 +116,7 @@ class HypernetworkModule(torch.nn.Module): state_dict[to] = x def forward(self, x): - return x + self.linear(x) * self.multiplier + return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1) def trainables(self): layer_structure = [] @@ -126,6 +129,21 @@ class HypernetworkModule(torch.nn.Module): def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength +#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check. +def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout): + if layer_structure is None: + layer_structure = [1, 2, 1] + if not use_dropout: + return [0] * len(layer_structure) + dropout_values = [0] + dropout_values.extend([0.3] * (len(layer_structure) - 3)) + if last_layer_dropout: + dropout_values.append(0.3) + else: + dropout_values.append(0) + dropout_values.append(0) + return dropout_values + class Hypernetwork: filename = None @@ -144,18 +162,22 @@ class Hypernetwork: self.add_layer_norm = add_layer_norm self.use_dropout = use_dropout self.activate_output = activate_output - self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True + self.last_layer_dropout = kwargs.get('last_layer_dropout', True) + self.dropout_structure = kwargs.get('dropout_structure', None) + if self.dropout_structure is None: + self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout) self.optimizer_name = None self.optimizer_state_dict = None + self.optional_info = None for size in enable_sizes or []: self.layers[size] = ( HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure), HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure), ) - self.eval_mode() + self.eval() def weights(self): res = [] @@ -164,14 +186,14 @@ class Hypernetwork: res += layer.parameters() return res - def train_mode(self): + def train(self, mode=True): for k, layers in self.layers.items(): for layer in layers: - layer.train() + layer.train(mode=mode) for param in layer.parameters(): - param.requires_grad = True + param.requires_grad = mode - def eval_mode(self): + def eval(self): for k, layers in self.layers.items(): for layer in layers: layer.eval() @@ -191,11 +213,13 @@ class Hypernetwork: state_dict['activation_func'] = self.activation_func state_dict['is_layer_norm'] = self.add_layer_norm state_dict['weight_initialization'] = self.weight_init - state_dict['use_dropout'] = self.use_dropout state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name state_dict['activate_output'] = self.activate_output - state_dict['last_layer_dropout'] = self.last_layer_dropout + state_dict['use_dropout'] = self.use_dropout + state_dict['dropout_structure'] = self.dropout_structure + state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout + state_dict['optional_info'] = self.optional_info if self.optional_info else None if self.optimizer_name is not None: optimizer_saved_dict['optimizer_name'] = self.optimizer_name @@ -215,43 +239,56 @@ class Hypernetwork: self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) print(self.layer_structure) + optional_info = state_dict.get('optional_info', None) + if optional_info is not None: + print(f"INFO:\n {optional_info}\n") + self.optional_info = optional_info self.activation_func = state_dict.get('activation_func', None) print(f"Activation function is {self.activation_func}") self.weight_init = state_dict.get('weight_initialization', 'Normal') print(f"Weight initialization is {self.weight_init}") self.add_layer_norm = state_dict.get('is_layer_norm', False) print(f"Layer norm is set to {self.add_layer_norm}") - self.use_dropout = state_dict.get('use_dropout', False) + self.dropout_structure = state_dict.get('dropout_structure', None) + self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False) print(f"Dropout usage is set to {self.use_dropout}" ) self.activate_output = state_dict.get('activate_output', True) print(f"Activate last layer is set to {self.activate_output}") self.last_layer_dropout = state_dict.get('last_layer_dropout', False) + # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0. + if self.dropout_structure is None: + print("Using previous dropout structure") + self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout) + print(f"Dropout structure is set to {self.dropout_structure}") optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} - self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') - print(f"Optimizer name is {self.optimizer_name}") + if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None): self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) else: self.optimizer_state_dict = None if self.optimizer_state_dict: + self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') print("Loaded existing optimizer from checkpoint") + print(f"Optimizer name is {self.optimizer_name}") else: + self.optimizer_name = "AdamW" print("No saved optimizer exists in checkpoint") for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + self.add_layer_norm, self.activate_output, self.dropout_structure), HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + self.add_layer_norm, self.activate_output, self.dropout_structure), ) self.name = state_dict.get('name', self.name) self.step = state_dict.get('step', 0) self.sd_checkpoint = state_dict.get('sd_checkpoint', None) self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) + self.eval() def list_hypernetworks(path): @@ -379,9 +416,10 @@ def report_statistics(loss_info:dict): print(e) -def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): +def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): # Remove illegal characters from name. name = "".join( x for x in name if (x.isalnum() or x in "._- ")) + assert name, "Name cannot be empty!" fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") if not overwrite_old: @@ -390,6 +428,11 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, if type(layer_structure) == str: layer_structure = [float(x.strip()) for x in layer_structure.split(",")] + if use_dropout and dropout_structure and type(dropout_structure) == str: + dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")] + else: + dropout_structure = [0] * len(layer_structure) + hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( name=name, enable_sizes=[int(x) for x in enable_sizes], @@ -398,6 +441,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, weight_init=weight_init, add_layer_norm=add_layer_norm, use_dropout=use_dropout, + dropout_structure=dropout_structure ) hypernet.save(fn) @@ -480,7 +524,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, shared.sd_model.first_stage_model.to(devices.cpu) weights = hypernetwork.weights() - hypernetwork.train_mode() + hypernetwork.train() # Here we use optimizer from saved HN, or we can specify as UI option. if hypernetwork.optimizer_name in optimizer_dict: @@ -594,7 +638,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if images_dir is not None and steps_done % create_image_every == 0: forced_filename = f'{hypernetwork_name}-{steps_done}' last_saved_image = os.path.join(images_dir, forced_filename) - hypernetwork.eval_mode() + hypernetwork.eval() + rng_state = torch.get_rng_state() + cuda_rng_state = None + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state_all() shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) @@ -627,7 +675,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - hypernetwork.train_mode() + torch.set_rng_state(rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state_all(cuda_rng_state) + hypernetwork.train() if image is not None: shared.state.current_image = image last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) @@ -649,7 +700,7 @@ Last saved image: {html.escape(last_saved_image)}
finally: pbar.leave = False pbar.close() - hypernetwork.eval_mode() + hypernetwork.eval() #report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index e7f9e593..81e3f519 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,8 +9,8 @@ from modules import devices, sd_hijack, shared not_available = ["hardswish", "multiheadattention"] keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) -def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): - filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout) +def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): + filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", "" diff --git a/modules/ui.py b/modules/ui.py index b6079aec..9b9081b5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1268,6 +1268,7 @@ def create_ui(): new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") with gr.Row(): @@ -1414,7 +1415,8 @@ def create_ui(): new_hypernetwork_activation_func, new_hypernetwork_initialization_option, new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout + new_hypernetwork_use_dropout, + new_hypernetwork_dropout_structure ], outputs=[ train_hypernetwork_name, -- cgit v1.2.3 From e9f8292a3a6792b722696fcf8e32b3fcb43ba436 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Tue, 10 Jan 2023 11:54:48 +0300 Subject: Split history ui.py to ui_progress.py --- modules/ui.py | 1928 ------------------------------------------------ modules/ui_progress.py | 1928 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1928 insertions(+), 1928 deletions(-) delete mode 100644 modules/ui.py create mode 100644 modules/ui_progress.py (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py deleted file mode 100644 index 9b9081b5..00000000 --- a/modules/ui.py +++ /dev/null @@ -1,1928 +0,0 @@ -import html -import json -import math -import mimetypes -import os -import platform -import random -import subprocess as sp -import sys -import tempfile -import time -import traceback -from functools import partial, reduce - -import gradio as gr -import gradio.routes -import gradio.utils -import numpy as np -from PIL import Image, PngImagePlugin -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call - -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru -from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path - -from modules.shared import opts, cmd_opts, restricted_opts - -import modules.codeformer_model -import modules.generation_parameters_copypaste as parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.scripts -import modules.shared as shared -import modules.styles -import modules.textual_inversion.ui -from modules import prompt_parser -from modules.images import save_image -from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.textual_inversion import textual_inversion -import modules.hypernetworks.ui -from modules.generation_parameters_copypaste import image_from_url_text - -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI -mimetypes.init() -mimetypes.add_type('application/javascript', '.js') - -if not cmd_opts.share and not cmd_opts.listen: - # fix gradio phoning home - gradio.utils.version_check = lambda: None - gradio.utils.get_local_ip_address = lambda: '127.0.0.1' - -if cmd_opts.ngrok is not None: - import modules.ngrok as ngrok - print('ngrok authtoken detected, trying to connect...') - ngrok.connect( - cmd_opts.ngrok, - cmd_opts.port if cmd_opts.port is not None else 7860, - cmd_opts.ngrok_region - ) - - -def gr_show(visible=True): - return {"visible": visible, "__type__": "update"} - - -sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" -sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None - -css_hide_progressbar = """ -.wrap .m-12 svg { display:none!important; } -.wrap .m-12::before { content:"Loading..." } -.wrap .z-20 svg { display:none!important; } -.wrap .z-20::before { content:"Loading..." } -.progress-bar { display:none!important; } -.meta-text { display:none!important; } -.meta-text-center { display:none!important; } -""" - -# Using constants for these since the variation selector isn't visible. -# Important that they exactly match script.js for tooltip to work. -random_symbol = '\U0001f3b2\ufe0f' # 🎲️ -reuse_symbol = '\u267b\ufe0f' # ♻️ -paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 -refresh_symbol = '\U0001f504' # 🔄 -save_style_symbol = '\U0001f4be' # 💾 -apply_style_symbol = '\U0001f4cb' # 📋 -clear_prompt_symbol = '\U0001F5D1' # 🗑️ - - -def plaintext_to_html(text): - text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" - return text - -def send_gradio_gallery_to_image(x): - if len(x) == 0: - return None - return image_from_url_text(x[0]) - -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") - - -def calc_time_left(progress, threshold, label, force_display, show_eta): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = (time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and show_eta) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False), gr_show(False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - # Show progress percentage and time left at the same moment, and base it also on steps done - show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 - - time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" - - image = gr_show(False) - preview_visibility = gr_show(False) - - if opts.show_progress_every_n_steps != 0: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr_show(True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr_show(False) - - return f"

{progressbar}

", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - -def visit(x, func, path=""): - if hasattr(x, 'children'): - for c in x.children: - visit(c, func, path) - elif x.label is not None: - func(path + "/" + str(x.label), x) - - -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] - - -def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): - from modules import processing, devices - - if not enable: - return "" - - p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) - - with devices.autocast(): - p.init([""], [0], [0]) - - return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" - - -def apply_styles(prompt, prompt_neg, style1_name, style2_name): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] - - -def interrogate(image): - prompt = shared.interrogator.interrogate(image.convert("RGB")) - - return gr_show(True) if prompt is None else prompt - - -def interrogate_deepbooru(image): - prompt = deepbooru.model.tag(image) - return gr_show(True) if prompt is None else prompt - - -def create_seed_inputs(target_interface): - with FormRow(elem_id=target_interface + '_seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') - seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') - - with gr.Group(elem_id=target_interface + '_subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) - - # Components to show/hide based on the 'Extra' checkbox - seed_extras = [] - - with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: - seed_extras.append(seed_extra_row_1) - subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') - subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') - - with FormRow(visible=False) as seed_extra_row_2: - seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') - - random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) - random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) - - def change_visibility(show): - return {comp: gr_show(show) for comp in seed_extras} - - seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) - - return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox - - - -def connect_clear_prompt(button): - """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" - button.click( - _js="clear_prompt", - fn=None, - inputs=[], - outputs=[], - ) - - -def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): - """ Connects a 'reuse (sub)seed' button's click event so that it copies last used - (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength - was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" - def copy_seed(gen_info_string: str, index): - res = -1 - - try: - gen_info = json.loads(gen_info_string) - index -= gen_info.get('index_of_first_image', 0) - - if is_subseed and gen_info.get('subseed_strength', 0) > 0: - all_subseeds = gen_info.get('all_subseeds', [-1]) - res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] - else: - all_seeds = gen_info.get('all_seeds', [-1]) - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] - - except json.decoder.JSONDecodeError as e: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) - - return [res, gr_show(False)] - - reuse_seed.click( - fn=copy_seed, - _js="(x, y) => [x, selected_gallery_index()]", - show_progress=False, - inputs=[generation_info, dummy_component], - outputs=[seed, dummy_component] - ) - - -def update_token_counter(text, steps): - try: - _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) - prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) - - except Exception: - # a parsing error can happen here during typing, and we don't want to bother the user with - # messages related to it in console - prompt_schedules = [[[steps, text]]] - - flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) - prompts = [prompt_text for step, prompt_text in flat_prompts] - token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) - style_class = ' class="red"' if (token_count > max_length) else "" - return f"{token_count}/{max_length}" - - -def create_toprow(is_img2img): - id_part = "img2img" if is_img2img else "txt2img" - - with gr.Row(elem_id="toprow"): - with gr.Column(scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, - placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, - placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - - clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[prompt, negative_prompt], - outputs=[prompt, negative_prompt], - ) - - button_interrogate = None - button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_id="interrogate_col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1): - with gr.Row(): - skip = gr.Button('Skip', elem_id=f"{id_part}_skip") - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") - submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(): - with gr.Column(scale=1, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button - - -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) - - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - -def apply_setting(key, value): - if value is None: - return gr.update() - - if shared.cmd_opts.freeze_settings: - return gr.update() - - # dont allow model to be swapped when model hash exists in prompt - if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: - return gr.update() - - if key == "sd_model_checkpoint": - ckpt_info = sd_models.get_closet_checkpoint_match(value) - - if ckpt_info is not None: - value = ckpt_info.title - else: - return gr.update() - - comp_args = opts.data_labels[key].component_args - if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: - return - - valtype = type(opts.data_labels[key].default) - oldval = opts.data.get(key, None) - opts.data[key] = valtype(value) if valtype != type(None) else value - if oldval != value and opts.data_labels[key].onchange is not None: - opts.data_labels[key].onchange() - - opts.save(shared.config_filename) - return value - - -def update_generation_info(args): - generation_info, html_info, img_index = args - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info - - -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - -def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel'): - with gr.Group(): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') - - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') - - with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - save_zip.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log - - -def create_sampler_and_steps_selection(choices, tabname): - if opts.samplers_in_dropdown: - with FormRow(elem_id=f"sampler_selection_{tabname}"): - sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - else: - with FormGroup(elem_id=f"sampler_selection_{tabname}"): - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - - return steps, sampler_index - - -def ordered_ui_categories(): - user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} - - for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): - yield category - - -def create_ui(): - import modules.img2img - import modules.txt2img - - reload_javascript() - - parameters_copypaste.reset() - - modules.scripts.scripts_current = modules.scripts.scripts_txt2img - modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) - - with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) - - dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, txt2img_preview, 'txt2img') - - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel', elem_id="txt2img_settings"): - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="txt2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "cfg": - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - - elif category == "checkboxes": - with FormRow(elem_id="txt2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) - - elif category == "hires_fix": - with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: - with FormRow(elem_id="txt2img_hires_fix_row1"): - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - - with FormRow(elem_id="txt2img_hires_fix_row2"): - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") - hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - - hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - for input in hr_resolution_preview_inputs: - input.change( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False, - ) - input.change( - None, - _js="onCalcResolutionHires", - inputs=hr_resolution_preview_inputs, - outputs=[], - show_progress=False, - ) - - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) - parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - txt2img_args = dict( - fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), - _js="submit", - inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - denoising_strength, - hr_scale, - hr_upscaler, - hr_second_pass_steps, - hr_resize_x, - hr_resize_y, - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) - - txt_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - txt_prompt_img - ], - outputs=[ - txt2img_prompt, - txt_prompt_img - ] - ) - - enable_hr.change( - fn=lambda x: gr_show(x), - inputs=[enable_hr], - outputs=[hr_options], - show_progress = False, - ) - - txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (hr_scale, "Hires upscale"), - (hr_upscaler, "Hires upscaler"), - (hr_second_pass_steps, "Hires steps"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - *modules.scripts.scripts_txt2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) - - txt2img_preview_params = [ - txt2img_prompt, - txt2img_negative_prompt, - steps, - sampler_index, - cfg_scale, - seed, - width, - height, - ] - - token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) - - modules.scripts.scripts_current = modules.scripts.scripts_img2img - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) - - with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') - - with FormRow().style(equal_height=False): - with gr.Column(variant='panel', elem_id="img2img_settings"): - - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) - init_img_with_mask_orig = gr.State(None) - - use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" - if use_color_sketch: - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) - - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") - - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") - - with FormRow(): - mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): - hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "cfg": - with FormGroup(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') - - elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() - - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) - parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - img2img_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - img2img_prompt_img - ], - outputs=[ - img2img_prompt, - img2img_prompt_img - ] - ) - - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - - img2img_args = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), - _js="submit_img2img", - inputs=[ - dummy_component, - img2img_prompt, - img2img_negative_prompt, - img2img_prompt_style, - img2img_prompt_style2, - init_img, - init_img_with_mask, - init_img_with_mask_orig, - init_img_inpaint, - init_mask_inpaint, - mask_mode, - steps, - sampler_index, - mask_blur, - mask_alpha, - inpainting_fill, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - denoising_strength, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - resize_mode, - inpaint_full_res, - inpaint_full_res_padding, - inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, - ] + custom_inputs, - outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - img2img_prompt.submit(**img2img_args) - submit.click(**img2img_args) - - img2img_interrogate.click( - fn=interrogate, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - - for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): - button.click( - fn=add_style, - _js="ask_for_style_name", - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], - ) - - for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): - button.click( - fn=apply_styles, - _js=js_func, - inputs=[prompt, negative_prompt, style1, style2], - outputs=[prompt, negative_prompt, style1, style2], - ) - - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - - img2img_paste_fields = [ - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (mask_blur, "Mask blur"), - *modules.scripts.scripts_img2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) - - modules.scripts.scripts_current = None - - with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image', elem_id="extras_single_tab"): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - - with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - - with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") - show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") - - result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, - html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=modules.extras.clear_cache, - inputs=[], outputs=[] - ) - - with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") - - with gr.Column(variant='panel'): - html = gr.HTML() - generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") - html2 = gr.HTML() - with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) - - image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), - inputs=[image], - outputs=[html, generation_info, html2], - ) - - with gr.Blocks(analytics_enabled=False) as modelmerger_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - - with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") - - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") - create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") - - tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") - create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") - - custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") - - with gr.Row(): - checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") - save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - - with gr.Column(variant='panel'): - submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - - with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row().style(equal_height=False): - gr.HTML(value="

See wiki for detailed explanation.

") - - with gr.Row().style(equal_height=False): - with gr.Tabs(elem_id="train_tabs"): - - with gr.Tab(label="Create embedding"): - new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") - initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") - - with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") - new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") - - with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") - process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") - - with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") - process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") - process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") - process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") - process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") - run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - def get_textual_inversion_template_names(): - return sorted([x for x in textual_inversion.textual_inversion_templates]) - - with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") - with FormRow(): - train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") - - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) - create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - - with FormRow(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") - - with FormRow(): - clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) - clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) - - with FormRow(): - batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") - - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - - with FormRow(): - template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) - create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") - - training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") - training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") - varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") - steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") - - with FormRow(): - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") - - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") - - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") - - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") - - with gr.Row(): - train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") - interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") - - params = script_callbacks.UiTrainTabParams(txt2img_preview_params) - - script_callbacks.ui_train_tabs_callback(params) - - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") - ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) - ti_progress = gr.HTML(elem_id="ti_progress", value="") - ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) - - create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, - inputs=[ - new_embedding_name, - initialization_text, - nvpt, - overwrite_old_embedding, - ], - outputs=[ - train_embedding_name, - ti_output, - ti_outcome, - ] - ) - - create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, - inputs=[ - new_hypernetwork_name, - new_hypernetwork_sizes, - overwrite_old_hypernetwork, - new_hypernetwork_layer_structure, - new_hypernetwork_activation_func, - new_hypernetwork_initialization_option, - new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout, - new_hypernetwork_dropout_structure - ], - outputs=[ - train_hypernetwork_name, - ti_output, - ti_outcome, - ] - ) - - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - - train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_embedding_name, - embedding_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - save_image_with_stored_embedding, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_hypernetwork_name, - hypernetwork_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - interrupt_training.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {str(t)} for key {key}') - - elem_id = "setting_"+key - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - with FormRow(): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - - components = [] - component_dict = {} - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' - - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return gr.update(value=value), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - with gr.Row(): - with gr.Column(scale=6): - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - with gr.Column(): - restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") - - result = gr.HTML(elem_id="settings_result") - - quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] - quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} - - quicksettings_list = [] - - previous_section = None - current_tab = None - with gr.Tabs(elem_id="settings"): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - elem_id, text = item.section - - if current_tab is not None: - current_tab.__exit__() - - current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) - current_tab.__enter__() - - previous_section = item.section - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - - if current_tab is not None: - current_tab.__exit__() - - with gr.TabItem("Actions"): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") - - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - def request_restart(): - shared.state.interrupt() - shared.state.need_restart = True - - restart_gradio.click( - fn=request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) - - interfaces = [ - (txt2img_interface, "txt2img", "txt2img"), - (img2img_interface, "img2img", "img2img"), - (extras_interface, "Extras", "extras"), - (pnginfo_interface, "PNG Info", "pnginfo"), - (modelmerger_interface, "Checkpoint Merger", "modelmerger"), - (train_interface, "Train", "ti"), - ] - - css = "" - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - with open(cssfile, "r", encoding="utf8") as file: - css += file.read() + "\n" - - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - css += file.read() + "\n" - - if not cmd_opts.no_progressbar_hiding: - css += css_hide_progressbar - - interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] - - extensions_interface = ui_extensions.create_ui() - interfaces += [(extensions_interface, "Extensions", "extensions")] - - with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings"): - for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component - - parameters_copypaste.integrate_settings_paste_fields(component_dict) - parameters_copypaste.run_bind() - - with gr.Tabs(elem_id="tabs") as tabs: - for interface, label, ifid in interfaces: - with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): - interface.render() - - if os.path.exists(os.path.join(script_path, "notification.mp3")): - audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") - - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for i, k, item in quicksettings_list: - component = component_dict[k] - - component.change( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [getattr(opts, key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - ) - - def modelmerger(*args): - try: - results = modules.extras.run_modelmerger(*args) - except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() # to remove the potentially missing models from the list - return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] - return results - - modelmerger_merge.click( - fn=modelmerger, - inputs=[ - primary_model_name, - secondary_model_name, - tertiary_model_name, - interp_method, - interp_amount, - save_as_half, - custom_name, - checkpoint_format, - ], - outputs=[ - submit_result, - primary_model_name, - secondary_model_name, - tertiary_model_name, - component_dict['sd_model_checkpoint'], - ] - ) - - ui_config_file = cmd_opts.ui_config_file - ui_settings = {} - settings_count = len(ui_settings) - error_loading = False - - try: - if os.path.exists(ui_config_file): - with open(ui_config_file, "r", encoding="utf8") as file: - ui_settings = json.load(file) - except Exception: - error_loading = True - print("Error loading settings:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - def loadsave(path, x): - def apply_field(obj, field, condition=None, init_field=None): - key = path + "/" + field - - if getattr(obj, 'custom_script_source', None) is not None: - key = 'customscript/' + obj.custom_script_source + '/' + key - - if getattr(obj, 'do_not_save_to_config', False): - return - - saved_value = ui_settings.get(key, None) - if saved_value is None: - ui_settings[key] = getattr(obj, field) - elif condition and not condition(saved_value): - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') - else: - setattr(obj, field, saved_value) - if init_field is not None: - init_field(saved_value) - - if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: - apply_field(x, 'visible') - - if type(x) == gr.Slider: - apply_field(x, 'value') - apply_field(x, 'minimum') - apply_field(x, 'maximum') - apply_field(x, 'step') - - if type(x) == gr.Radio: - apply_field(x, 'value', lambda val: val in x.choices) - - if type(x) == gr.Checkbox: - apply_field(x, 'value') - - if type(x) == gr.Textbox: - apply_field(x, 'value') - - if type(x) == gr.Number: - apply_field(x, 'value') - - if type(x) == gr.Dropdown: - apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) - - visit(txt2img_interface, loadsave, "txt2img") - visit(img2img_interface, loadsave, "img2img") - visit(extras_interface, loadsave, "extras") - visit(modelmerger_interface, loadsave, "modelmerger") - visit(train_interface, loadsave, "train") - - if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): - with open(ui_config_file, "w", encoding="utf8") as file: - json.dump(ui_settings, file, indent=4) - - return demo - - -def reload_javascript(): - with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' - - scripts_list = modules.scripts.list_scripts("javascript", ".js") - - for basedir, filename, path in scripts_list: - with open(path, "r", encoding="utf8") as jsfile: - javascript += f"\n" - - if cmd_opts.theme is not None: - javascript += f"\n\n" - - javascript += f"\n" - - def template_response(*args, **kwargs): - res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace( - b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res - - gradio.routes.templates.TemplateResponse = template_response - - -if not hasattr(shared, 'GradioTemplateResponseOriginal'): - shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse - - -def versions_html(): - import torch - import launch - - python_version = ".".join([str(x) for x in sys.version_info[0:3]]) - commit = launch.commit_hash() - short_commit = commit[0:8] - - if shared.xformers_available: - import xformers - xformers_version = xformers.__version__ - else: - xformers_version = "N/A" - - return f""" -python: {python_version} - •  -torch: {torch.__version__} - •  -xformers: {xformers_version} - •  -gradio: {gr.__version__} - •  -commit: {short_commit} -""" diff --git a/modules/ui_progress.py b/modules/ui_progress.py new file mode 100644 index 00000000..9b9081b5 --- /dev/null +++ b/modules/ui_progress.py @@ -0,0 +1,1928 @@ +import html +import json +import math +import mimetypes +import os +import platform +import random +import subprocess as sp +import sys +import tempfile +import time +import traceback +from functools import partial, reduce + +import gradio as gr +import gradio.routes +import gradio.utils +import numpy as np +from PIL import Image, PngImagePlugin +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call + +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru +from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML +from modules.paths import script_path + +from modules.shared import opts, cmd_opts, restricted_opts + +import modules.codeformer_model +import modules.generation_parameters_copypaste as parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.scripts +import modules.shared as shared +import modules.styles +import modules.textual_inversion.ui +from modules import prompt_parser +from modules.images import save_image +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img +from modules.textual_inversion import textual_inversion +import modules.hypernetworks.ui +from modules.generation_parameters_copypaste import image_from_url_text + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +if not cmd_opts.share and not cmd_opts.listen: + # fix gradio phoning home + gradio.utils.version_check = lambda: None + gradio.utils.get_local_ip_address = lambda: '127.0.0.1' + +if cmd_opts.ngrok is not None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect( + cmd_opts.ngrok, + cmd_opts.port if cmd_opts.port is not None else 7860, + cmd_opts.ngrok_region + ) + + +def gr_show(visible=True): + return {"visible": visible, "__type__": "update"} + + +sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" +sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None + +css_hide_progressbar = """ +.wrap .m-12 svg { display:none!important; } +.wrap .m-12::before { content:"Loading..." } +.wrap .z-20 svg { display:none!important; } +.wrap .z-20::before { content:"Loading..." } +.progress-bar { display:none!important; } +.meta-text { display:none!important; } +.meta-text-center { display:none!important; } +""" + +# Using constants for these since the variation selector isn't visible. +# Important that they exactly match script.js for tooltip to work. +random_symbol = '\U0001f3b2\ufe0f' # 🎲️ +reuse_symbol = '\u267b\ufe0f' # ♻️ +paste_symbol = '\u2199\ufe0f' # ↙ +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +apply_style_symbol = '\U0001f4cb' # 📋 +clear_prompt_symbol = '\U0001F5D1' # 🗑️ + + +def plaintext_to_html(text): + text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" + return text + +def send_gradio_gallery_to_image(x): + if len(x) == 0: + return None + return image_from_url_text(x[0]) + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = opts.outdir_save + save_to_dirs = opts.use_save_to_dirs_for_ui + extension: str = opts.samples_format + start_index = 0 + + if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(opts.outdir_save, exist_ok=True) + + with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") + + +def calc_time_left(progress, threshold, label, force_display, show_eta): + if progress == 0: + return "" + else: + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + eta_relative = eta-time_since_start + if (eta_relative > threshold and show_eta) or force_display: + if eta_relative > 3600: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + elif eta_relative > 60: + return label + time.strftime('%M:%S', time.gmtime(eta_relative)) + else: + return label + time.strftime('%Ss', time.gmtime(eta_relative)) + else: + return "" + + +def check_progress_call(id_part): + if shared.state.job_count == 0: + return "", gr_show(False), gr_show(False), gr_show(False) + + progress = 0 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + # Show progress percentage and time left at the same moment, and base it also on steps done + show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 + + time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) + if time_left != "": + shared.state.time_left_force_display = True + + progress = min(progress, 1) + + progressbar = "" + if opts.show_progressbar: + progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" + + image = gr_show(False) + preview_visibility = gr_show(False) + + if opts.show_progress_every_n_steps != 0: + shared.state.set_current_image() + image = shared.state.current_image + + if image is None: + image = gr.update(value=None) + else: + preview_visibility = gr_show(True) + + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"

{progressbar}

", preview_visibility, image, textinfo_result + + +def check_progress_call_initial(id_part): + shared.state.job_count = -1 + shared.state.current_latent = None + shared.state.current_image = None + shared.state.textinfo = None + shared.state.time_start = time.time() + shared.state.time_left_force_display = False + + return check_progress_call(id_part) + + +def visit(x, func, path=""): + if hasattr(x, 'children'): + for c in x.children: + visit(c, func, path) + elif x.label is not None: + func(path + "/" + str(x.label), x) + + +def add_style(name: str, prompt: str, negative_prompt: str): + if name is None: + return [gr_show() for x in range(4)] + + style = modules.styles.PromptStyle(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we + # reserialize all styles every time we save them + shared.prompt_styles.save_styles(shared.styles_filename) + + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] + + +def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): + from modules import processing, devices + + if not enable: + return "" + + p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) + + with devices.autocast(): + p.init([""], [0], [0]) + + return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" + + +def apply_styles(prompt, prompt_neg, style1_name, style2_name): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) + + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] + + +def interrogate(image): + prompt = shared.interrogator.interrogate(image.convert("RGB")) + + return gr_show(True) if prompt is None else prompt + + +def interrogate_deepbooru(image): + prompt = deepbooru.model.tag(image) + return gr_show(True) if prompt is None else prompt + + +def create_seed_inputs(target_interface): + with FormRow(elem_id=target_interface + '_seed_row'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') + seed.style(container=False) + random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') + reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') + + with gr.Group(elem_id=target_interface + '_subseed_show_box'): + seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) + + # Components to show/hide based on the 'Extra' checkbox + seed_extras = [] + + with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: + seed_extras.append(seed_extra_row_1) + subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') + subseed.style(container=False) + random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') + reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') + + with FormRow(visible=False) as seed_extra_row_2: + seed_extras.append(seed_extra_row_2) + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') + + random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) + random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) + + def change_visibility(show): + return {comp: gr_show(show) for comp in seed_extras} + + seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) + + return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox + + + +def connect_clear_prompt(button): + """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" + button.click( + _js="clear_prompt", + fn=None, + inputs=[], + outputs=[], + ) + + +def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): + """ Connects a 'reuse (sub)seed' button's click event so that it copies last used + (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength + was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" + def copy_seed(gen_info_string: str, index): + res = -1 + + try: + gen_info = json.loads(gen_info_string) + index -= gen_info.get('index_of_first_image', 0) + + if is_subseed and gen_info.get('subseed_strength', 0) > 0: + all_subseeds = gen_info.get('all_subseeds', [-1]) + res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] + else: + all_seeds = gen_info.get('all_seeds', [-1]) + res = all_seeds[index if 0 <= index < len(all_seeds) else 0] + + except json.decoder.JSONDecodeError as e: + if gen_info_string != '': + print("Error parsing JSON generation info:", file=sys.stderr) + print(gen_info_string, file=sys.stderr) + + return [res, gr_show(False)] + + reuse_seed.click( + fn=copy_seed, + _js="(x, y) => [x, selected_gallery_index()]", + show_progress=False, + inputs=[generation_info, dummy_component], + outputs=[seed, dummy_component] + ) + + +def update_token_counter(text, steps): + try: + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) + + except Exception: + # a parsing error can happen here during typing, and we don't want to bother the user with + # messages related to it in console + prompt_schedules = [[[steps, text]]] + + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) + prompts = [prompt_text for step, prompt_text in flat_prompts] + token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) + style_class = ' class="red"' if (token_count > max_length) else "" + return f"{token_count}/{max_length}" + + +def create_toprow(is_img2img): + id_part = "img2img" if is_img2img else "txt2img" + + with gr.Row(elem_id="toprow"): + with gr.Column(scale=6): + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, + placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, + placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Column(scale=1, elem_id="roll_col"): + paste = gr.Button(value=paste_symbol, elem_id="paste") + save_style = gr.Button(value=save_style_symbol, elem_id="style_create") + prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") + clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + + clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[prompt, negative_prompt], + outputs=[prompt, negative_prompt], + ) + + button_interrogate = None + button_deepbooru = None + if is_img2img: + with gr.Column(scale=1, elem_id="interrogate_col"): + button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + + with gr.Column(scale=1): + with gr.Row(): + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") + interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") + submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + + interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + with gr.Row(): + with gr.Column(scale=1, elem_id="style_pos_col"): + prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + + with gr.Column(scale=1, elem_id="style_neg_col"): + prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + + return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + + +def setup_progressbar(progressbar, preview, id_part, textinfo=None): + if textinfo is None: + textinfo = gr.HTML(visible=False) + + check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) + check_progress.click( + fn=lambda: check_progress_call(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) + check_progress_initial.click( + fn=lambda: check_progress_call_initial(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + +def apply_setting(key, value): + if value is None: + return gr.update() + + if shared.cmd_opts.freeze_settings: + return gr.update() + + # dont allow model to be swapped when model hash exists in prompt + if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: + return gr.update() + + if key == "sd_model_checkpoint": + ckpt_info = sd_models.get_closet_checkpoint_match(value) + + if ckpt_info is not None: + value = ckpt_info.title + else: + return gr.update() + + comp_args = opts.data_labels[key].component_args + if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: + return + + valtype = type(opts.data_labels[key].default) + oldval = opts.data.get(key, None) + opts.data[key] = valtype(value) if valtype != type(None) else value + if oldval != value and opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + return value + + +def update_generation_info(args): + generation_info, html_info, img_index = args + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info + return plaintext_to_html(generation_info["infotexts"][img_index]) + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + +def create_output_panel(tabname, outdir): + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel'): + with gr.Group(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(elem_id=f"image_buttons_{tabname}"): + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') + + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + + open_folder_button.click( + fn=lambda: open_folder(opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') + + with gr.Group(): + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="(x, y) => [x, y, selected_gallery_index()]", + inputs=[generation_info, html_info], + outputs=[html_info], + preprocess=False + ) + + save.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + save_zip.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + else: + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + + +def create_sampler_and_steps_selection(choices, tabname): + if opts.samplers_in_dropdown: + with FormRow(elem_id=f"sampler_selection_{tabname}"): + sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + else: + with FormGroup(elem_id=f"sampler_selection_{tabname}"): + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + + return steps, sampler_index + + +def ordered_ui_categories(): + user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} + + for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): + yield category + + +def create_ui(): + import modules.img2img + import modules.txt2img + + reload_javascript() + + parameters_copypaste.reset() + + modules.scripts.scripts_current = modules.scripts.scripts_txt2img + modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + + with gr.Blocks(analytics_enabled=False) as txt2img_interface: + txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + + dummy_component = gr.Label(visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Row(elem_id='txt2img_progress_row'): + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="txt2img_progressbar") + txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) + setup_progressbar(progressbar, txt2img_preview, 'txt2img') + + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel', elem_id="txt2img_settings"): + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="txt2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "cfg": + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') + + elif category == "checkboxes": + with FormRow(elem_id="txt2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") + enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) + + elif category == "hires_fix": + with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: + with FormRow(elem_id="txt2img_hires_fix_row1"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + with FormRow(elem_id="txt2img_hires_fix_row2"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="txt2img_script_container"): + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] + for input in hr_resolution_preview_inputs: + input.change( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + input.change( + None, + _js="onCalcResolutionHires", + inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) + + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) + parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + txt2img_args = dict( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), + _js="submit", + inputs=[ + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_style, + txt2img_prompt_style2, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + denoising_strength, + hr_scale, + hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, + ] + custom_inputs, + + outputs=[ + txt2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + txt2img_prompt.submit(**txt2img_args) + submit.click(**txt2img_args) + + txt_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + txt_prompt_img + ], + outputs=[ + txt2img_prompt, + txt_prompt_img + ] + ) + + enable_hr.change( + fn=lambda x: gr_show(x), + inputs=[enable_hr], + outputs=[hr_options], + show_progress = False, + ) + + txt2img_paste_fields = [ + (txt2img_prompt, "Prompt"), + (txt2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (enable_hr, lambda d: "Denoising strength" in d), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), + (hr_second_pass_steps, "Hires steps"), + (hr_resize_x, "Hires resize-1"), + (hr_resize_y, "Hires resize-2"), + *modules.scripts.scripts_txt2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) + + txt2img_preview_params = [ + txt2img_prompt, + txt2img_negative_prompt, + steps, + sampler_index, + cfg_scale, + seed, + width, + height, + ] + + token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) + + modules.scripts.scripts_current = modules.scripts.scripts_img2img + modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + + with gr.Blocks(analytics_enabled=False) as img2img_interface: + img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) + + with gr.Row(elem_id='img2img_progress_row'): + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="img2img_progressbar") + img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) + setup_progressbar(progressbar, img2img_preview, 'img2img') + + with FormRow().style(equal_height=False): + with gr.Column(variant='panel', elem_id="img2img_settings"): + + with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) + init_img_with_mask_orig = gr.State(None) + + use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" + if use_color_sketch: + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) + + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") + + with FormRow(): + mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): + hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "cfg": + with FormGroup(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') + + elif category == "checkboxes": + with FormRow(elem_id="img2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="img2img_script_container"): + custom_inputs = modules.scripts.scripts_img2img.setup_ui() + + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) + parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + img2img_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + img2img_prompt_img + ], + outputs=[ + img2img_prompt, + img2img_prompt_img + ] + ) + + mask_mode.change( + lambda mode, img: { + init_img_with_mask: gr_show(mode == 0), + init_img_inpaint: gr_show(mode == 1), + init_mask_inpaint: gr_show(mode == 1), + }, + inputs=[mask_mode, init_img_with_mask], + outputs=[ + init_img_with_mask, + init_img_inpaint, + init_mask_inpaint, + ], + ) + + img2img_args = dict( + fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), + _js="submit_img2img", + inputs=[ + dummy_component, + img2img_prompt, + img2img_negative_prompt, + img2img_prompt_style, + img2img_prompt_style2, + init_img, + init_img_with_mask, + init_img_with_mask_orig, + init_img_inpaint, + init_mask_inpaint, + mask_mode, + steps, + sampler_index, + mask_blur, + mask_alpha, + inpainting_fill, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + denoising_strength, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + resize_mode, + inpaint_full_res, + inpaint_full_res_padding, + inpainting_mask_invert, + img2img_batch_input_dir, + img2img_batch_output_dir, + ] + custom_inputs, + outputs=[ + img2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + img2img_prompt.submit(**img2img_args) + submit.click(**img2img_args) + + img2img_interrogate.click( + fn=interrogate, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] + style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] + + for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): + button.click( + fn=add_style, + _js="ask_for_style_name", + # Have to pass empty dummy component here, because the JavaScript and Python function have to accept + # the same number of parameters, but we only know the style-name after the JavaScript prompt + inputs=[dummy_component, prompt, negative_prompt], + outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], + ) + + for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): + button.click( + fn=apply_styles, + _js=js_func, + inputs=[prompt, negative_prompt, style1, style2], + outputs=[prompt, negative_prompt, style1, style2], + ) + + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + + img2img_paste_fields = [ + (img2img_prompt, "Prompt"), + (img2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (mask_blur, "Mask blur"), + *modules.scripts.scripts_img2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) + parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) + + modules.scripts.scripts_current = None + + with gr.Blocks(analytics_enabled=False) as extras_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image', elem_id="extras_single_tab"): + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + + with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") + + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): + with gr.Group(): + with gr.Row(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with gr.Group(): + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + + with gr.Group(): + extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") + + with gr.Group(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") + + with gr.Group(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") + + with gr.Group(): + upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") + + result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) + + submit.click( + fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), + _js="get_extras_tab_index", + inputs=[ + dummy_component, + dummy_component, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + gfpgan_visibility, + codeformer_visibility, + codeformer_weight, + upscaling_resize, + upscaling_resize_w, + upscaling_resize_h, + upscaling_crop, + extras_upscaler_1, + extras_upscaler_2, + extras_upscaler_2_visibility, + upscale_before_face_fix, + ], + outputs=[ + result_images, + html_info_x, + html_info, + ] + ) + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=modules.extras.clear_cache, + inputs=[], outputs=[] + ) + + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") + + with gr.Column(variant='panel'): + html = gr.HTML() + generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") + html2 = gr.HTML() + with gr.Row(): + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) + parameters_copypaste.bind_buttons(buttons, image, generation_info) + + image.change( + fn=wrap_gradio_call(modules.extras.run_pnginfo), + inputs=[image], + outputs=[html, generation_info, html2], + ) + + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") + + with gr.Row(): + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") + + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") + + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") + + custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + + with gr.Row(): + checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") + save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") + + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + + with gr.Column(variant='panel'): + submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + + with gr.Blocks(analytics_enabled=False) as train_interface: + with gr.Row().style(equal_height=False): + gr.HTML(value="

See wiki for detailed explanation.

") + + with gr.Row().style(equal_height=False): + with gr.Tabs(elem_id="train_tabs"): + + with gr.Tab(label="Create embedding"): + new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") + initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") + + with gr.Tab(label="Create hypernetwork"): + new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") + + with gr.Tab(label="Preprocess images"): + process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") + process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") + process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") + process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") + + with gr.Row(): + process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") + process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") + process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") + process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") + + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") + + with gr.Row(visible=False) as process_focal_crop_row: + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") + process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") + run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") + + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + + process_focal_crop.change( + fn=lambda show: gr_show(show), + inputs=[process_focal_crop], + outputs=[process_focal_crop_row], + ) + + def get_textual_inversion_template_names(): + return sorted([x for x in textual_inversion.textual_inversion_templates]) + + with gr.Tab(label="Train"): + gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") + with FormRow(): + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) + create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") + + with FormRow(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") + + with FormRow(): + clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) + clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) + + with FormRow(): + batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") + + with FormRow(): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") + training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") + steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") + + with FormRow(): + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") + + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") + + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") + + with gr.Row(): + train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") + interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") + + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + initialization_text, + nvpt, + overwrite_old_embedding, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + create_hypernetwork.click( + fn=modules.hypernetworks.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + new_hypernetwork_sizes, + overwrite_old_hypernetwork, + new_hypernetwork_layer_structure, + new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout, + new_hypernetwork_dropout_structure + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + process_src, + process_dst, + process_width, + process_height, + preprocess_txt_action, + process_flip, + process_split, + process_caption, + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, + process_focal_crop, + process_focal_crop_face_weight, + process_focal_crop_entropy_weight, + process_focal_crop_edges_weight, + process_focal_crop_debug, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + embedding_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + save_image_with_stored_embedding, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_hypernetwork_name, + hypernetwork_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + interrupt_preprocessing.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {str(t)} for key {key}') + + elem_id = "setting_"+key + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + with FormRow(): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + components = [] + component_dict = {} + + script_callbacks.ui_settings_callback() + opts.reorder() + + def run_settings(*args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' + + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return gr.update(value=value), opts.dumpjson() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + with gr.Row(): + with gr.Column(scale=6): + settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") + + result = gr.HTML(elem_id="settings_result") + + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} + + quicksettings_list = [] + + previous_section = None + current_tab = None + with gr.Tabs(elem_id="settings"): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + elem_id, text = item.section + + if current_tab is not None: + current_tab.__exit__() + + current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) + current_tab.__enter__() + + previous_section = item.section + + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + + if current_tab is not None: + current_tab.__exit__() + + with gr.TabItem("Actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + + if os.path.exists("html/licenses.html"): + with open("html/licenses.html", encoding="utf8") as file: + with gr.TabItem("Licenses"): + gr.HTML(file.read(), elem_id="licenses") + + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + def request_restart(): + shared.state.interrupt() + shared.state.need_restart = True + + restart_gradio.click( + fn=request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + interfaces = [ + (txt2img_interface, "txt2img", "txt2img"), + (img2img_interface, "img2img", "img2img"), + (extras_interface, "Extras", "extras"), + (pnginfo_interface, "PNG Info", "pnginfo"), + (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (train_interface, "Train", "ti"), + ] + + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + with open(cssfile, "r", encoding="utf8") as file: + css += file.read() + "\n" + + if os.path.exists(os.path.join(script_path, "user.css")): + with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + css += file.read() + "\n" + + if not cmd_opts.no_progressbar_hiding: + css += css_hide_progressbar + + interfaces += script_callbacks.ui_tabs_callback() + interfaces += [(settings_interface, "Settings", "settings")] + + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: + with gr.Row(elem_id="quicksettings"): + for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): + component = create_setting_component(k, is_quicksettings=True) + component_dict[k] = component + + parameters_copypaste.integrate_settings_paste_fields(component_dict) + parameters_copypaste.run_bind() + + with gr.Tabs(elem_id="tabs") as tabs: + for interface, label, ifid in interfaces: + with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): + interface.render() + + if os.path.exists(os.path.join(script_path, "notification.mp3")): + audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) + + if os.path.exists("html/footer.html"): + with open("html/footer.html", encoding="utf8") as file: + footer = file.read() + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") + + text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + settings_submit.click( + fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), + inputs=components, + outputs=[text_settings, result], + ) + + for i, k, item in quicksettings_list: + component = component_dict[k] + + component.change( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + + def get_settings_values(): + return [getattr(opts, key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[component_dict[k] for k in component_keys], + ) + + def modelmerger(*args): + try: + results = modules.extras.run_modelmerger(*args) + except Exception as e: + print("Error loading/saving model file:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + modules.sd_models.list_models() # to remove the potentially missing models from the list + return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] + return results + + modelmerger_merge.click( + fn=modelmerger, + inputs=[ + primary_model_name, + secondary_model_name, + tertiary_model_name, + interp_method, + interp_amount, + save_as_half, + custom_name, + checkpoint_format, + ], + outputs=[ + submit_result, + primary_model_name, + secondary_model_name, + tertiary_model_name, + component_dict['sd_model_checkpoint'], + ] + ) + + ui_config_file = cmd_opts.ui_config_file + ui_settings = {} + settings_count = len(ui_settings) + error_loading = False + + try: + if os.path.exists(ui_config_file): + with open(ui_config_file, "r", encoding="utf8") as file: + ui_settings = json.load(file) + except Exception: + error_loading = True + print("Error loading settings:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + def loadsave(path, x): + def apply_field(obj, field, condition=None, init_field=None): + key = path + "/" + field + + if getattr(obj, 'custom_script_source', None) is not None: + key = 'customscript/' + obj.custom_script_source + '/' + key + + if getattr(obj, 'do_not_save_to_config', False): + return + + saved_value = ui_settings.get(key, None) + if saved_value is None: + ui_settings[key] = getattr(obj, field) + elif condition and not condition(saved_value): + print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + else: + setattr(obj, field, saved_value) + if init_field is not None: + init_field(saved_value) + + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: + apply_field(x, 'visible') + + if type(x) == gr.Slider: + apply_field(x, 'value') + apply_field(x, 'minimum') + apply_field(x, 'maximum') + apply_field(x, 'step') + + if type(x) == gr.Radio: + apply_field(x, 'value', lambda val: val in x.choices) + + if type(x) == gr.Checkbox: + apply_field(x, 'value') + + if type(x) == gr.Textbox: + apply_field(x, 'value') + + if type(x) == gr.Number: + apply_field(x, 'value') + + if type(x) == gr.Dropdown: + apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) + + visit(txt2img_interface, loadsave, "txt2img") + visit(img2img_interface, loadsave, "img2img") + visit(extras_interface, loadsave, "extras") + visit(modelmerger_interface, loadsave, "modelmerger") + visit(train_interface, loadsave, "train") + + if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): + with open(ui_config_file, "w", encoding="utf8") as file: + json.dump(ui_settings, file, indent=4) + + return demo + + +def reload_javascript(): + with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: + javascript = f'' + + scripts_list = modules.scripts.list_scripts("javascript", ".js") + + for basedir, filename, path in scripts_list: + with open(path, "r", encoding="utf8") as jsfile: + javascript += f"\n" + + if cmd_opts.theme is not None: + javascript += f"\n\n" + + javascript += f"\n" + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace( + b'', f'{javascript}'.encode("utf8")) + res.init_headers() + return res + + gradio.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse + + +def versions_html(): + import torch + import launch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = launch.commit_hash() + short_commit = commit[0:8] + + if shared.xformers_available: + import xformers + xformers_version = xformers.__version__ + else: + xformers_version = "N/A" + + return f""" +python: {python_version} + •  +torch: {torch.__version__} + •  +xformers: {xformers_version} + •  +gradio: {gr.__version__} + •  +commit: {short_commit} +""" -- cgit v1.2.3 From 27ea6949d3206c9a52fa77db587bac0012cb0b52 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Tue, 10 Jan 2023 11:54:48 +0300 Subject: Split history ui.py to ui_progress.py --- modules/temp | 1928 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ modules/ui.py | 1928 --------------------------------------------------------- 2 files changed, 1928 insertions(+), 1928 deletions(-) create mode 100644 modules/temp delete mode 100644 modules/ui.py (limited to 'modules') diff --git a/modules/temp b/modules/temp new file mode 100644 index 00000000..9b9081b5 --- /dev/null +++ b/modules/temp @@ -0,0 +1,1928 @@ +import html +import json +import math +import mimetypes +import os +import platform +import random +import subprocess as sp +import sys +import tempfile +import time +import traceback +from functools import partial, reduce + +import gradio as gr +import gradio.routes +import gradio.utils +import numpy as np +from PIL import Image, PngImagePlugin +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call + +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru +from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML +from modules.paths import script_path + +from modules.shared import opts, cmd_opts, restricted_opts + +import modules.codeformer_model +import modules.generation_parameters_copypaste as parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.scripts +import modules.shared as shared +import modules.styles +import modules.textual_inversion.ui +from modules import prompt_parser +from modules.images import save_image +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img +from modules.textual_inversion import textual_inversion +import modules.hypernetworks.ui +from modules.generation_parameters_copypaste import image_from_url_text + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +if not cmd_opts.share and not cmd_opts.listen: + # fix gradio phoning home + gradio.utils.version_check = lambda: None + gradio.utils.get_local_ip_address = lambda: '127.0.0.1' + +if cmd_opts.ngrok is not None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect( + cmd_opts.ngrok, + cmd_opts.port if cmd_opts.port is not None else 7860, + cmd_opts.ngrok_region + ) + + +def gr_show(visible=True): + return {"visible": visible, "__type__": "update"} + + +sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" +sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None + +css_hide_progressbar = """ +.wrap .m-12 svg { display:none!important; } +.wrap .m-12::before { content:"Loading..." } +.wrap .z-20 svg { display:none!important; } +.wrap .z-20::before { content:"Loading..." } +.progress-bar { display:none!important; } +.meta-text { display:none!important; } +.meta-text-center { display:none!important; } +""" + +# Using constants for these since the variation selector isn't visible. +# Important that they exactly match script.js for tooltip to work. +random_symbol = '\U0001f3b2\ufe0f' # 🎲️ +reuse_symbol = '\u267b\ufe0f' # ♻️ +paste_symbol = '\u2199\ufe0f' # ↙ +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +apply_style_symbol = '\U0001f4cb' # 📋 +clear_prompt_symbol = '\U0001F5D1' # 🗑️ + + +def plaintext_to_html(text): + text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" + return text + +def send_gradio_gallery_to_image(x): + if len(x) == 0: + return None + return image_from_url_text(x[0]) + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = opts.outdir_save + save_to_dirs = opts.use_save_to_dirs_for_ui + extension: str = opts.samples_format + start_index = 0 + + if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(opts.outdir_save, exist_ok=True) + + with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") + + +def calc_time_left(progress, threshold, label, force_display, show_eta): + if progress == 0: + return "" + else: + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + eta_relative = eta-time_since_start + if (eta_relative > threshold and show_eta) or force_display: + if eta_relative > 3600: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + elif eta_relative > 60: + return label + time.strftime('%M:%S', time.gmtime(eta_relative)) + else: + return label + time.strftime('%Ss', time.gmtime(eta_relative)) + else: + return "" + + +def check_progress_call(id_part): + if shared.state.job_count == 0: + return "", gr_show(False), gr_show(False), gr_show(False) + + progress = 0 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + # Show progress percentage and time left at the same moment, and base it also on steps done + show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 + + time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) + if time_left != "": + shared.state.time_left_force_display = True + + progress = min(progress, 1) + + progressbar = "" + if opts.show_progressbar: + progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" + + image = gr_show(False) + preview_visibility = gr_show(False) + + if opts.show_progress_every_n_steps != 0: + shared.state.set_current_image() + image = shared.state.current_image + + if image is None: + image = gr.update(value=None) + else: + preview_visibility = gr_show(True) + + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"

{progressbar}

", preview_visibility, image, textinfo_result + + +def check_progress_call_initial(id_part): + shared.state.job_count = -1 + shared.state.current_latent = None + shared.state.current_image = None + shared.state.textinfo = None + shared.state.time_start = time.time() + shared.state.time_left_force_display = False + + return check_progress_call(id_part) + + +def visit(x, func, path=""): + if hasattr(x, 'children'): + for c in x.children: + visit(c, func, path) + elif x.label is not None: + func(path + "/" + str(x.label), x) + + +def add_style(name: str, prompt: str, negative_prompt: str): + if name is None: + return [gr_show() for x in range(4)] + + style = modules.styles.PromptStyle(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we + # reserialize all styles every time we save them + shared.prompt_styles.save_styles(shared.styles_filename) + + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] + + +def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): + from modules import processing, devices + + if not enable: + return "" + + p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) + + with devices.autocast(): + p.init([""], [0], [0]) + + return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" + + +def apply_styles(prompt, prompt_neg, style1_name, style2_name): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) + + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] + + +def interrogate(image): + prompt = shared.interrogator.interrogate(image.convert("RGB")) + + return gr_show(True) if prompt is None else prompt + + +def interrogate_deepbooru(image): + prompt = deepbooru.model.tag(image) + return gr_show(True) if prompt is None else prompt + + +def create_seed_inputs(target_interface): + with FormRow(elem_id=target_interface + '_seed_row'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') + seed.style(container=False) + random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') + reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') + + with gr.Group(elem_id=target_interface + '_subseed_show_box'): + seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) + + # Components to show/hide based on the 'Extra' checkbox + seed_extras = [] + + with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: + seed_extras.append(seed_extra_row_1) + subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') + subseed.style(container=False) + random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') + reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') + + with FormRow(visible=False) as seed_extra_row_2: + seed_extras.append(seed_extra_row_2) + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') + + random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) + random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) + + def change_visibility(show): + return {comp: gr_show(show) for comp in seed_extras} + + seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) + + return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox + + + +def connect_clear_prompt(button): + """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" + button.click( + _js="clear_prompt", + fn=None, + inputs=[], + outputs=[], + ) + + +def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): + """ Connects a 'reuse (sub)seed' button's click event so that it copies last used + (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength + was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" + def copy_seed(gen_info_string: str, index): + res = -1 + + try: + gen_info = json.loads(gen_info_string) + index -= gen_info.get('index_of_first_image', 0) + + if is_subseed and gen_info.get('subseed_strength', 0) > 0: + all_subseeds = gen_info.get('all_subseeds', [-1]) + res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] + else: + all_seeds = gen_info.get('all_seeds', [-1]) + res = all_seeds[index if 0 <= index < len(all_seeds) else 0] + + except json.decoder.JSONDecodeError as e: + if gen_info_string != '': + print("Error parsing JSON generation info:", file=sys.stderr) + print(gen_info_string, file=sys.stderr) + + return [res, gr_show(False)] + + reuse_seed.click( + fn=copy_seed, + _js="(x, y) => [x, selected_gallery_index()]", + show_progress=False, + inputs=[generation_info, dummy_component], + outputs=[seed, dummy_component] + ) + + +def update_token_counter(text, steps): + try: + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) + + except Exception: + # a parsing error can happen here during typing, and we don't want to bother the user with + # messages related to it in console + prompt_schedules = [[[steps, text]]] + + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) + prompts = [prompt_text for step, prompt_text in flat_prompts] + token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) + style_class = ' class="red"' if (token_count > max_length) else "" + return f"{token_count}/{max_length}" + + +def create_toprow(is_img2img): + id_part = "img2img" if is_img2img else "txt2img" + + with gr.Row(elem_id="toprow"): + with gr.Column(scale=6): + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, + placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, + placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Column(scale=1, elem_id="roll_col"): + paste = gr.Button(value=paste_symbol, elem_id="paste") + save_style = gr.Button(value=save_style_symbol, elem_id="style_create") + prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") + clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + + clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[prompt, negative_prompt], + outputs=[prompt, negative_prompt], + ) + + button_interrogate = None + button_deepbooru = None + if is_img2img: + with gr.Column(scale=1, elem_id="interrogate_col"): + button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + + with gr.Column(scale=1): + with gr.Row(): + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") + interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") + submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + + interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + with gr.Row(): + with gr.Column(scale=1, elem_id="style_pos_col"): + prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + + with gr.Column(scale=1, elem_id="style_neg_col"): + prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + + return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + + +def setup_progressbar(progressbar, preview, id_part, textinfo=None): + if textinfo is None: + textinfo = gr.HTML(visible=False) + + check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) + check_progress.click( + fn=lambda: check_progress_call(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) + check_progress_initial.click( + fn=lambda: check_progress_call_initial(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + +def apply_setting(key, value): + if value is None: + return gr.update() + + if shared.cmd_opts.freeze_settings: + return gr.update() + + # dont allow model to be swapped when model hash exists in prompt + if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: + return gr.update() + + if key == "sd_model_checkpoint": + ckpt_info = sd_models.get_closet_checkpoint_match(value) + + if ckpt_info is not None: + value = ckpt_info.title + else: + return gr.update() + + comp_args = opts.data_labels[key].component_args + if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: + return + + valtype = type(opts.data_labels[key].default) + oldval = opts.data.get(key, None) + opts.data[key] = valtype(value) if valtype != type(None) else value + if oldval != value and opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + return value + + +def update_generation_info(args): + generation_info, html_info, img_index = args + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info + return plaintext_to_html(generation_info["infotexts"][img_index]) + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + +def create_output_panel(tabname, outdir): + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel'): + with gr.Group(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(elem_id=f"image_buttons_{tabname}"): + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') + + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + + open_folder_button.click( + fn=lambda: open_folder(opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') + + with gr.Group(): + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="(x, y) => [x, y, selected_gallery_index()]", + inputs=[generation_info, html_info], + outputs=[html_info], + preprocess=False + ) + + save.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + save_zip.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + else: + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + + +def create_sampler_and_steps_selection(choices, tabname): + if opts.samplers_in_dropdown: + with FormRow(elem_id=f"sampler_selection_{tabname}"): + sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + else: + with FormGroup(elem_id=f"sampler_selection_{tabname}"): + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + + return steps, sampler_index + + +def ordered_ui_categories(): + user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} + + for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): + yield category + + +def create_ui(): + import modules.img2img + import modules.txt2img + + reload_javascript() + + parameters_copypaste.reset() + + modules.scripts.scripts_current = modules.scripts.scripts_txt2img + modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + + with gr.Blocks(analytics_enabled=False) as txt2img_interface: + txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + + dummy_component = gr.Label(visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Row(elem_id='txt2img_progress_row'): + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="txt2img_progressbar") + txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) + setup_progressbar(progressbar, txt2img_preview, 'txt2img') + + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel', elem_id="txt2img_settings"): + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="txt2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "cfg": + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') + + elif category == "checkboxes": + with FormRow(elem_id="txt2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") + enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) + + elif category == "hires_fix": + with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: + with FormRow(elem_id="txt2img_hires_fix_row1"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + with FormRow(elem_id="txt2img_hires_fix_row2"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="txt2img_script_container"): + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] + for input in hr_resolution_preview_inputs: + input.change( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + input.change( + None, + _js="onCalcResolutionHires", + inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) + + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) + parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + txt2img_args = dict( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), + _js="submit", + inputs=[ + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_style, + txt2img_prompt_style2, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + denoising_strength, + hr_scale, + hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, + ] + custom_inputs, + + outputs=[ + txt2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + txt2img_prompt.submit(**txt2img_args) + submit.click(**txt2img_args) + + txt_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + txt_prompt_img + ], + outputs=[ + txt2img_prompt, + txt_prompt_img + ] + ) + + enable_hr.change( + fn=lambda x: gr_show(x), + inputs=[enable_hr], + outputs=[hr_options], + show_progress = False, + ) + + txt2img_paste_fields = [ + (txt2img_prompt, "Prompt"), + (txt2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (enable_hr, lambda d: "Denoising strength" in d), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), + (hr_second_pass_steps, "Hires steps"), + (hr_resize_x, "Hires resize-1"), + (hr_resize_y, "Hires resize-2"), + *modules.scripts.scripts_txt2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) + + txt2img_preview_params = [ + txt2img_prompt, + txt2img_negative_prompt, + steps, + sampler_index, + cfg_scale, + seed, + width, + height, + ] + + token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) + + modules.scripts.scripts_current = modules.scripts.scripts_img2img + modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + + with gr.Blocks(analytics_enabled=False) as img2img_interface: + img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) + + with gr.Row(elem_id='img2img_progress_row'): + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="img2img_progressbar") + img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) + setup_progressbar(progressbar, img2img_preview, 'img2img') + + with FormRow().style(equal_height=False): + with gr.Column(variant='panel', elem_id="img2img_settings"): + + with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) + init_img_with_mask_orig = gr.State(None) + + use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" + if use_color_sketch: + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) + + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") + + with FormRow(): + mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): + hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "cfg": + with FormGroup(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') + + elif category == "checkboxes": + with FormRow(elem_id="img2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="img2img_script_container"): + custom_inputs = modules.scripts.scripts_img2img.setup_ui() + + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) + parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + img2img_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + img2img_prompt_img + ], + outputs=[ + img2img_prompt, + img2img_prompt_img + ] + ) + + mask_mode.change( + lambda mode, img: { + init_img_with_mask: gr_show(mode == 0), + init_img_inpaint: gr_show(mode == 1), + init_mask_inpaint: gr_show(mode == 1), + }, + inputs=[mask_mode, init_img_with_mask], + outputs=[ + init_img_with_mask, + init_img_inpaint, + init_mask_inpaint, + ], + ) + + img2img_args = dict( + fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), + _js="submit_img2img", + inputs=[ + dummy_component, + img2img_prompt, + img2img_negative_prompt, + img2img_prompt_style, + img2img_prompt_style2, + init_img, + init_img_with_mask, + init_img_with_mask_orig, + init_img_inpaint, + init_mask_inpaint, + mask_mode, + steps, + sampler_index, + mask_blur, + mask_alpha, + inpainting_fill, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + denoising_strength, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + resize_mode, + inpaint_full_res, + inpaint_full_res_padding, + inpainting_mask_invert, + img2img_batch_input_dir, + img2img_batch_output_dir, + ] + custom_inputs, + outputs=[ + img2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + img2img_prompt.submit(**img2img_args) + submit.click(**img2img_args) + + img2img_interrogate.click( + fn=interrogate, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] + style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] + + for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): + button.click( + fn=add_style, + _js="ask_for_style_name", + # Have to pass empty dummy component here, because the JavaScript and Python function have to accept + # the same number of parameters, but we only know the style-name after the JavaScript prompt + inputs=[dummy_component, prompt, negative_prompt], + outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], + ) + + for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): + button.click( + fn=apply_styles, + _js=js_func, + inputs=[prompt, negative_prompt, style1, style2], + outputs=[prompt, negative_prompt, style1, style2], + ) + + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + + img2img_paste_fields = [ + (img2img_prompt, "Prompt"), + (img2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (mask_blur, "Mask blur"), + *modules.scripts.scripts_img2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) + parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) + + modules.scripts.scripts_current = None + + with gr.Blocks(analytics_enabled=False) as extras_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image', elem_id="extras_single_tab"): + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + + with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") + + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): + with gr.Group(): + with gr.Row(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with gr.Group(): + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + + with gr.Group(): + extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") + + with gr.Group(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") + + with gr.Group(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") + + with gr.Group(): + upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") + + result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) + + submit.click( + fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), + _js="get_extras_tab_index", + inputs=[ + dummy_component, + dummy_component, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + gfpgan_visibility, + codeformer_visibility, + codeformer_weight, + upscaling_resize, + upscaling_resize_w, + upscaling_resize_h, + upscaling_crop, + extras_upscaler_1, + extras_upscaler_2, + extras_upscaler_2_visibility, + upscale_before_face_fix, + ], + outputs=[ + result_images, + html_info_x, + html_info, + ] + ) + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=modules.extras.clear_cache, + inputs=[], outputs=[] + ) + + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") + + with gr.Column(variant='panel'): + html = gr.HTML() + generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") + html2 = gr.HTML() + with gr.Row(): + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) + parameters_copypaste.bind_buttons(buttons, image, generation_info) + + image.change( + fn=wrap_gradio_call(modules.extras.run_pnginfo), + inputs=[image], + outputs=[html, generation_info, html2], + ) + + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") + + with gr.Row(): + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") + + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") + + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") + + custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + + with gr.Row(): + checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") + save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") + + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + + with gr.Column(variant='panel'): + submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + + with gr.Blocks(analytics_enabled=False) as train_interface: + with gr.Row().style(equal_height=False): + gr.HTML(value="

See wiki for detailed explanation.

") + + with gr.Row().style(equal_height=False): + with gr.Tabs(elem_id="train_tabs"): + + with gr.Tab(label="Create embedding"): + new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") + initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") + + with gr.Tab(label="Create hypernetwork"): + new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") + + with gr.Tab(label="Preprocess images"): + process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") + process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") + process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") + process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") + + with gr.Row(): + process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") + process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") + process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") + process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") + + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") + + with gr.Row(visible=False) as process_focal_crop_row: + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") + process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") + run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") + + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + + process_focal_crop.change( + fn=lambda show: gr_show(show), + inputs=[process_focal_crop], + outputs=[process_focal_crop_row], + ) + + def get_textual_inversion_template_names(): + return sorted([x for x in textual_inversion.textual_inversion_templates]) + + with gr.Tab(label="Train"): + gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") + with FormRow(): + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) + create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") + + with FormRow(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") + + with FormRow(): + clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) + clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) + + with FormRow(): + batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") + + with FormRow(): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") + training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") + steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") + + with FormRow(): + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") + + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") + + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") + + with gr.Row(): + train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") + interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") + + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + initialization_text, + nvpt, + overwrite_old_embedding, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + create_hypernetwork.click( + fn=modules.hypernetworks.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + new_hypernetwork_sizes, + overwrite_old_hypernetwork, + new_hypernetwork_layer_structure, + new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout, + new_hypernetwork_dropout_structure + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + process_src, + process_dst, + process_width, + process_height, + preprocess_txt_action, + process_flip, + process_split, + process_caption, + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, + process_focal_crop, + process_focal_crop_face_weight, + process_focal_crop_entropy_weight, + process_focal_crop_edges_weight, + process_focal_crop_debug, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + embedding_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + save_image_with_stored_embedding, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_hypernetwork_name, + hypernetwork_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + interrupt_preprocessing.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {str(t)} for key {key}') + + elem_id = "setting_"+key + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + with FormRow(): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + components = [] + component_dict = {} + + script_callbacks.ui_settings_callback() + opts.reorder() + + def run_settings(*args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' + + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return gr.update(value=value), opts.dumpjson() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + with gr.Row(): + with gr.Column(scale=6): + settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") + + result = gr.HTML(elem_id="settings_result") + + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} + + quicksettings_list = [] + + previous_section = None + current_tab = None + with gr.Tabs(elem_id="settings"): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + elem_id, text = item.section + + if current_tab is not None: + current_tab.__exit__() + + current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) + current_tab.__enter__() + + previous_section = item.section + + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + + if current_tab is not None: + current_tab.__exit__() + + with gr.TabItem("Actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + + if os.path.exists("html/licenses.html"): + with open("html/licenses.html", encoding="utf8") as file: + with gr.TabItem("Licenses"): + gr.HTML(file.read(), elem_id="licenses") + + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + def request_restart(): + shared.state.interrupt() + shared.state.need_restart = True + + restart_gradio.click( + fn=request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + interfaces = [ + (txt2img_interface, "txt2img", "txt2img"), + (img2img_interface, "img2img", "img2img"), + (extras_interface, "Extras", "extras"), + (pnginfo_interface, "PNG Info", "pnginfo"), + (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (train_interface, "Train", "ti"), + ] + + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + with open(cssfile, "r", encoding="utf8") as file: + css += file.read() + "\n" + + if os.path.exists(os.path.join(script_path, "user.css")): + with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + css += file.read() + "\n" + + if not cmd_opts.no_progressbar_hiding: + css += css_hide_progressbar + + interfaces += script_callbacks.ui_tabs_callback() + interfaces += [(settings_interface, "Settings", "settings")] + + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: + with gr.Row(elem_id="quicksettings"): + for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): + component = create_setting_component(k, is_quicksettings=True) + component_dict[k] = component + + parameters_copypaste.integrate_settings_paste_fields(component_dict) + parameters_copypaste.run_bind() + + with gr.Tabs(elem_id="tabs") as tabs: + for interface, label, ifid in interfaces: + with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): + interface.render() + + if os.path.exists(os.path.join(script_path, "notification.mp3")): + audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) + + if os.path.exists("html/footer.html"): + with open("html/footer.html", encoding="utf8") as file: + footer = file.read() + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") + + text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + settings_submit.click( + fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), + inputs=components, + outputs=[text_settings, result], + ) + + for i, k, item in quicksettings_list: + component = component_dict[k] + + component.change( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + + def get_settings_values(): + return [getattr(opts, key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[component_dict[k] for k in component_keys], + ) + + def modelmerger(*args): + try: + results = modules.extras.run_modelmerger(*args) + except Exception as e: + print("Error loading/saving model file:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + modules.sd_models.list_models() # to remove the potentially missing models from the list + return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] + return results + + modelmerger_merge.click( + fn=modelmerger, + inputs=[ + primary_model_name, + secondary_model_name, + tertiary_model_name, + interp_method, + interp_amount, + save_as_half, + custom_name, + checkpoint_format, + ], + outputs=[ + submit_result, + primary_model_name, + secondary_model_name, + tertiary_model_name, + component_dict['sd_model_checkpoint'], + ] + ) + + ui_config_file = cmd_opts.ui_config_file + ui_settings = {} + settings_count = len(ui_settings) + error_loading = False + + try: + if os.path.exists(ui_config_file): + with open(ui_config_file, "r", encoding="utf8") as file: + ui_settings = json.load(file) + except Exception: + error_loading = True + print("Error loading settings:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + def loadsave(path, x): + def apply_field(obj, field, condition=None, init_field=None): + key = path + "/" + field + + if getattr(obj, 'custom_script_source', None) is not None: + key = 'customscript/' + obj.custom_script_source + '/' + key + + if getattr(obj, 'do_not_save_to_config', False): + return + + saved_value = ui_settings.get(key, None) + if saved_value is None: + ui_settings[key] = getattr(obj, field) + elif condition and not condition(saved_value): + print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + else: + setattr(obj, field, saved_value) + if init_field is not None: + init_field(saved_value) + + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: + apply_field(x, 'visible') + + if type(x) == gr.Slider: + apply_field(x, 'value') + apply_field(x, 'minimum') + apply_field(x, 'maximum') + apply_field(x, 'step') + + if type(x) == gr.Radio: + apply_field(x, 'value', lambda val: val in x.choices) + + if type(x) == gr.Checkbox: + apply_field(x, 'value') + + if type(x) == gr.Textbox: + apply_field(x, 'value') + + if type(x) == gr.Number: + apply_field(x, 'value') + + if type(x) == gr.Dropdown: + apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) + + visit(txt2img_interface, loadsave, "txt2img") + visit(img2img_interface, loadsave, "img2img") + visit(extras_interface, loadsave, "extras") + visit(modelmerger_interface, loadsave, "modelmerger") + visit(train_interface, loadsave, "train") + + if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): + with open(ui_config_file, "w", encoding="utf8") as file: + json.dump(ui_settings, file, indent=4) + + return demo + + +def reload_javascript(): + with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: + javascript = f'' + + scripts_list = modules.scripts.list_scripts("javascript", ".js") + + for basedir, filename, path in scripts_list: + with open(path, "r", encoding="utf8") as jsfile: + javascript += f"\n" + + if cmd_opts.theme is not None: + javascript += f"\n\n" + + javascript += f"\n" + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace( + b'', f'{javascript}'.encode("utf8")) + res.init_headers() + return res + + gradio.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse + + +def versions_html(): + import torch + import launch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = launch.commit_hash() + short_commit = commit[0:8] + + if shared.xformers_available: + import xformers + xformers_version = xformers.__version__ + else: + xformers_version = "N/A" + + return f""" +python: {python_version} + •  +torch: {torch.__version__} + •  +xformers: {xformers_version} + •  +gradio: {gr.__version__} + •  +commit: {short_commit} +""" diff --git a/modules/ui.py b/modules/ui.py deleted file mode 100644 index 9b9081b5..00000000 --- a/modules/ui.py +++ /dev/null @@ -1,1928 +0,0 @@ -import html -import json -import math -import mimetypes -import os -import platform -import random -import subprocess as sp -import sys -import tempfile -import time -import traceback -from functools import partial, reduce - -import gradio as gr -import gradio.routes -import gradio.utils -import numpy as np -from PIL import Image, PngImagePlugin -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call - -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru -from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path - -from modules.shared import opts, cmd_opts, restricted_opts - -import modules.codeformer_model -import modules.generation_parameters_copypaste as parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.scripts -import modules.shared as shared -import modules.styles -import modules.textual_inversion.ui -from modules import prompt_parser -from modules.images import save_image -from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.textual_inversion import textual_inversion -import modules.hypernetworks.ui -from modules.generation_parameters_copypaste import image_from_url_text - -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI -mimetypes.init() -mimetypes.add_type('application/javascript', '.js') - -if not cmd_opts.share and not cmd_opts.listen: - # fix gradio phoning home - gradio.utils.version_check = lambda: None - gradio.utils.get_local_ip_address = lambda: '127.0.0.1' - -if cmd_opts.ngrok is not None: - import modules.ngrok as ngrok - print('ngrok authtoken detected, trying to connect...') - ngrok.connect( - cmd_opts.ngrok, - cmd_opts.port if cmd_opts.port is not None else 7860, - cmd_opts.ngrok_region - ) - - -def gr_show(visible=True): - return {"visible": visible, "__type__": "update"} - - -sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" -sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None - -css_hide_progressbar = """ -.wrap .m-12 svg { display:none!important; } -.wrap .m-12::before { content:"Loading..." } -.wrap .z-20 svg { display:none!important; } -.wrap .z-20::before { content:"Loading..." } -.progress-bar { display:none!important; } -.meta-text { display:none!important; } -.meta-text-center { display:none!important; } -""" - -# Using constants for these since the variation selector isn't visible. -# Important that they exactly match script.js for tooltip to work. -random_symbol = '\U0001f3b2\ufe0f' # 🎲️ -reuse_symbol = '\u267b\ufe0f' # ♻️ -paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 -refresh_symbol = '\U0001f504' # 🔄 -save_style_symbol = '\U0001f4be' # 💾 -apply_style_symbol = '\U0001f4cb' # 📋 -clear_prompt_symbol = '\U0001F5D1' # 🗑️ - - -def plaintext_to_html(text): - text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" - return text - -def send_gradio_gallery_to_image(x): - if len(x) == 0: - return None - return image_from_url_text(x[0]) - -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") - - -def calc_time_left(progress, threshold, label, force_display, show_eta): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = (time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and show_eta) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False), gr_show(False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - # Show progress percentage and time left at the same moment, and base it also on steps done - show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 - - time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" - - image = gr_show(False) - preview_visibility = gr_show(False) - - if opts.show_progress_every_n_steps != 0: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr_show(True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr_show(False) - - return f"

{progressbar}

", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - -def visit(x, func, path=""): - if hasattr(x, 'children'): - for c in x.children: - visit(c, func, path) - elif x.label is not None: - func(path + "/" + str(x.label), x) - - -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] - - -def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): - from modules import processing, devices - - if not enable: - return "" - - p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) - - with devices.autocast(): - p.init([""], [0], [0]) - - return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" - - -def apply_styles(prompt, prompt_neg, style1_name, style2_name): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] - - -def interrogate(image): - prompt = shared.interrogator.interrogate(image.convert("RGB")) - - return gr_show(True) if prompt is None else prompt - - -def interrogate_deepbooru(image): - prompt = deepbooru.model.tag(image) - return gr_show(True) if prompt is None else prompt - - -def create_seed_inputs(target_interface): - with FormRow(elem_id=target_interface + '_seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') - seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') - - with gr.Group(elem_id=target_interface + '_subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) - - # Components to show/hide based on the 'Extra' checkbox - seed_extras = [] - - with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: - seed_extras.append(seed_extra_row_1) - subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') - subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') - - with FormRow(visible=False) as seed_extra_row_2: - seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') - - random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) - random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) - - def change_visibility(show): - return {comp: gr_show(show) for comp in seed_extras} - - seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) - - return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox - - - -def connect_clear_prompt(button): - """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" - button.click( - _js="clear_prompt", - fn=None, - inputs=[], - outputs=[], - ) - - -def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): - """ Connects a 'reuse (sub)seed' button's click event so that it copies last used - (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength - was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" - def copy_seed(gen_info_string: str, index): - res = -1 - - try: - gen_info = json.loads(gen_info_string) - index -= gen_info.get('index_of_first_image', 0) - - if is_subseed and gen_info.get('subseed_strength', 0) > 0: - all_subseeds = gen_info.get('all_subseeds', [-1]) - res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] - else: - all_seeds = gen_info.get('all_seeds', [-1]) - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] - - except json.decoder.JSONDecodeError as e: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) - - return [res, gr_show(False)] - - reuse_seed.click( - fn=copy_seed, - _js="(x, y) => [x, selected_gallery_index()]", - show_progress=False, - inputs=[generation_info, dummy_component], - outputs=[seed, dummy_component] - ) - - -def update_token_counter(text, steps): - try: - _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) - prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) - - except Exception: - # a parsing error can happen here during typing, and we don't want to bother the user with - # messages related to it in console - prompt_schedules = [[[steps, text]]] - - flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) - prompts = [prompt_text for step, prompt_text in flat_prompts] - token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) - style_class = ' class="red"' if (token_count > max_length) else "" - return f"{token_count}/{max_length}" - - -def create_toprow(is_img2img): - id_part = "img2img" if is_img2img else "txt2img" - - with gr.Row(elem_id="toprow"): - with gr.Column(scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, - placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, - placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - - clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[prompt, negative_prompt], - outputs=[prompt, negative_prompt], - ) - - button_interrogate = None - button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_id="interrogate_col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1): - with gr.Row(): - skip = gr.Button('Skip', elem_id=f"{id_part}_skip") - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") - submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(): - with gr.Column(scale=1, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button - - -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) - - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - -def apply_setting(key, value): - if value is None: - return gr.update() - - if shared.cmd_opts.freeze_settings: - return gr.update() - - # dont allow model to be swapped when model hash exists in prompt - if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: - return gr.update() - - if key == "sd_model_checkpoint": - ckpt_info = sd_models.get_closet_checkpoint_match(value) - - if ckpt_info is not None: - value = ckpt_info.title - else: - return gr.update() - - comp_args = opts.data_labels[key].component_args - if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: - return - - valtype = type(opts.data_labels[key].default) - oldval = opts.data.get(key, None) - opts.data[key] = valtype(value) if valtype != type(None) else value - if oldval != value and opts.data_labels[key].onchange is not None: - opts.data_labels[key].onchange() - - opts.save(shared.config_filename) - return value - - -def update_generation_info(args): - generation_info, html_info, img_index = args - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info - - -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - -def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel'): - with gr.Group(): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') - - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') - - with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - save_zip.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log - - -def create_sampler_and_steps_selection(choices, tabname): - if opts.samplers_in_dropdown: - with FormRow(elem_id=f"sampler_selection_{tabname}"): - sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - else: - with FormGroup(elem_id=f"sampler_selection_{tabname}"): - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - - return steps, sampler_index - - -def ordered_ui_categories(): - user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} - - for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): - yield category - - -def create_ui(): - import modules.img2img - import modules.txt2img - - reload_javascript() - - parameters_copypaste.reset() - - modules.scripts.scripts_current = modules.scripts.scripts_txt2img - modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) - - with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) - - dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, txt2img_preview, 'txt2img') - - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel', elem_id="txt2img_settings"): - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="txt2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "cfg": - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - - elif category == "checkboxes": - with FormRow(elem_id="txt2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) - - elif category == "hires_fix": - with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: - with FormRow(elem_id="txt2img_hires_fix_row1"): - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - - with FormRow(elem_id="txt2img_hires_fix_row2"): - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") - hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - - hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - for input in hr_resolution_preview_inputs: - input.change( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False, - ) - input.change( - None, - _js="onCalcResolutionHires", - inputs=hr_resolution_preview_inputs, - outputs=[], - show_progress=False, - ) - - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) - parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - txt2img_args = dict( - fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), - _js="submit", - inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - denoising_strength, - hr_scale, - hr_upscaler, - hr_second_pass_steps, - hr_resize_x, - hr_resize_y, - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) - - txt_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - txt_prompt_img - ], - outputs=[ - txt2img_prompt, - txt_prompt_img - ] - ) - - enable_hr.change( - fn=lambda x: gr_show(x), - inputs=[enable_hr], - outputs=[hr_options], - show_progress = False, - ) - - txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (hr_scale, "Hires upscale"), - (hr_upscaler, "Hires upscaler"), - (hr_second_pass_steps, "Hires steps"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - *modules.scripts.scripts_txt2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) - - txt2img_preview_params = [ - txt2img_prompt, - txt2img_negative_prompt, - steps, - sampler_index, - cfg_scale, - seed, - width, - height, - ] - - token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) - - modules.scripts.scripts_current = modules.scripts.scripts_img2img - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) - - with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') - - with FormRow().style(equal_height=False): - with gr.Column(variant='panel', elem_id="img2img_settings"): - - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) - init_img_with_mask_orig = gr.State(None) - - use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" - if use_color_sketch: - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) - - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") - - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") - - with FormRow(): - mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): - hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "cfg": - with FormGroup(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') - - elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() - - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) - parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - img2img_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - img2img_prompt_img - ], - outputs=[ - img2img_prompt, - img2img_prompt_img - ] - ) - - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - - img2img_args = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), - _js="submit_img2img", - inputs=[ - dummy_component, - img2img_prompt, - img2img_negative_prompt, - img2img_prompt_style, - img2img_prompt_style2, - init_img, - init_img_with_mask, - init_img_with_mask_orig, - init_img_inpaint, - init_mask_inpaint, - mask_mode, - steps, - sampler_index, - mask_blur, - mask_alpha, - inpainting_fill, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - denoising_strength, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - resize_mode, - inpaint_full_res, - inpaint_full_res_padding, - inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, - ] + custom_inputs, - outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - img2img_prompt.submit(**img2img_args) - submit.click(**img2img_args) - - img2img_interrogate.click( - fn=interrogate, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - - for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): - button.click( - fn=add_style, - _js="ask_for_style_name", - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], - ) - - for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): - button.click( - fn=apply_styles, - _js=js_func, - inputs=[prompt, negative_prompt, style1, style2], - outputs=[prompt, negative_prompt, style1, style2], - ) - - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - - img2img_paste_fields = [ - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (mask_blur, "Mask blur"), - *modules.scripts.scripts_img2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) - - modules.scripts.scripts_current = None - - with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image', elem_id="extras_single_tab"): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - - with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - - with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") - show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") - - result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, - html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=modules.extras.clear_cache, - inputs=[], outputs=[] - ) - - with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") - - with gr.Column(variant='panel'): - html = gr.HTML() - generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") - html2 = gr.HTML() - with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) - - image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), - inputs=[image], - outputs=[html, generation_info, html2], - ) - - with gr.Blocks(analytics_enabled=False) as modelmerger_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - - with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") - - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") - create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") - - tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") - create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") - - custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") - - with gr.Row(): - checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") - save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - - with gr.Column(variant='panel'): - submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - - with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row().style(equal_height=False): - gr.HTML(value="

See wiki for detailed explanation.

") - - with gr.Row().style(equal_height=False): - with gr.Tabs(elem_id="train_tabs"): - - with gr.Tab(label="Create embedding"): - new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") - initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") - - with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") - new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") - - with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") - process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") - - with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") - process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") - process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") - process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") - process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") - run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - def get_textual_inversion_template_names(): - return sorted([x for x in textual_inversion.textual_inversion_templates]) - - with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") - with FormRow(): - train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") - - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) - create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - - with FormRow(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") - - with FormRow(): - clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) - clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) - - with FormRow(): - batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") - - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - - with FormRow(): - template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) - create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") - - training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") - training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") - varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") - steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") - - with FormRow(): - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") - - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") - - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") - - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") - - with gr.Row(): - train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") - interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") - - params = script_callbacks.UiTrainTabParams(txt2img_preview_params) - - script_callbacks.ui_train_tabs_callback(params) - - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") - ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) - ti_progress = gr.HTML(elem_id="ti_progress", value="") - ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) - - create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, - inputs=[ - new_embedding_name, - initialization_text, - nvpt, - overwrite_old_embedding, - ], - outputs=[ - train_embedding_name, - ti_output, - ti_outcome, - ] - ) - - create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, - inputs=[ - new_hypernetwork_name, - new_hypernetwork_sizes, - overwrite_old_hypernetwork, - new_hypernetwork_layer_structure, - new_hypernetwork_activation_func, - new_hypernetwork_initialization_option, - new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout, - new_hypernetwork_dropout_structure - ], - outputs=[ - train_hypernetwork_name, - ti_output, - ti_outcome, - ] - ) - - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - - train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_embedding_name, - embedding_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - save_image_with_stored_embedding, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_hypernetwork_name, - hypernetwork_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - interrupt_training.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {str(t)} for key {key}') - - elem_id = "setting_"+key - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - with FormRow(): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - - components = [] - component_dict = {} - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' - - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return gr.update(value=value), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - with gr.Row(): - with gr.Column(scale=6): - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - with gr.Column(): - restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") - - result = gr.HTML(elem_id="settings_result") - - quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] - quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} - - quicksettings_list = [] - - previous_section = None - current_tab = None - with gr.Tabs(elem_id="settings"): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - elem_id, text = item.section - - if current_tab is not None: - current_tab.__exit__() - - current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) - current_tab.__enter__() - - previous_section = item.section - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - - if current_tab is not None: - current_tab.__exit__() - - with gr.TabItem("Actions"): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") - - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - def request_restart(): - shared.state.interrupt() - shared.state.need_restart = True - - restart_gradio.click( - fn=request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) - - interfaces = [ - (txt2img_interface, "txt2img", "txt2img"), - (img2img_interface, "img2img", "img2img"), - (extras_interface, "Extras", "extras"), - (pnginfo_interface, "PNG Info", "pnginfo"), - (modelmerger_interface, "Checkpoint Merger", "modelmerger"), - (train_interface, "Train", "ti"), - ] - - css = "" - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - with open(cssfile, "r", encoding="utf8") as file: - css += file.read() + "\n" - - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - css += file.read() + "\n" - - if not cmd_opts.no_progressbar_hiding: - css += css_hide_progressbar - - interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] - - extensions_interface = ui_extensions.create_ui() - interfaces += [(extensions_interface, "Extensions", "extensions")] - - with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings"): - for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component - - parameters_copypaste.integrate_settings_paste_fields(component_dict) - parameters_copypaste.run_bind() - - with gr.Tabs(elem_id="tabs") as tabs: - for interface, label, ifid in interfaces: - with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): - interface.render() - - if os.path.exists(os.path.join(script_path, "notification.mp3")): - audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") - - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for i, k, item in quicksettings_list: - component = component_dict[k] - - component.change( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [getattr(opts, key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - ) - - def modelmerger(*args): - try: - results = modules.extras.run_modelmerger(*args) - except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() # to remove the potentially missing models from the list - return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] - return results - - modelmerger_merge.click( - fn=modelmerger, - inputs=[ - primary_model_name, - secondary_model_name, - tertiary_model_name, - interp_method, - interp_amount, - save_as_half, - custom_name, - checkpoint_format, - ], - outputs=[ - submit_result, - primary_model_name, - secondary_model_name, - tertiary_model_name, - component_dict['sd_model_checkpoint'], - ] - ) - - ui_config_file = cmd_opts.ui_config_file - ui_settings = {} - settings_count = len(ui_settings) - error_loading = False - - try: - if os.path.exists(ui_config_file): - with open(ui_config_file, "r", encoding="utf8") as file: - ui_settings = json.load(file) - except Exception: - error_loading = True - print("Error loading settings:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - def loadsave(path, x): - def apply_field(obj, field, condition=None, init_field=None): - key = path + "/" + field - - if getattr(obj, 'custom_script_source', None) is not None: - key = 'customscript/' + obj.custom_script_source + '/' + key - - if getattr(obj, 'do_not_save_to_config', False): - return - - saved_value = ui_settings.get(key, None) - if saved_value is None: - ui_settings[key] = getattr(obj, field) - elif condition and not condition(saved_value): - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') - else: - setattr(obj, field, saved_value) - if init_field is not None: - init_field(saved_value) - - if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: - apply_field(x, 'visible') - - if type(x) == gr.Slider: - apply_field(x, 'value') - apply_field(x, 'minimum') - apply_field(x, 'maximum') - apply_field(x, 'step') - - if type(x) == gr.Radio: - apply_field(x, 'value', lambda val: val in x.choices) - - if type(x) == gr.Checkbox: - apply_field(x, 'value') - - if type(x) == gr.Textbox: - apply_field(x, 'value') - - if type(x) == gr.Number: - apply_field(x, 'value') - - if type(x) == gr.Dropdown: - apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) - - visit(txt2img_interface, loadsave, "txt2img") - visit(img2img_interface, loadsave, "img2img") - visit(extras_interface, loadsave, "extras") - visit(modelmerger_interface, loadsave, "modelmerger") - visit(train_interface, loadsave, "train") - - if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): - with open(ui_config_file, "w", encoding="utf8") as file: - json.dump(ui_settings, file, indent=4) - - return demo - - -def reload_javascript(): - with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' - - scripts_list = modules.scripts.list_scripts("javascript", ".js") - - for basedir, filename, path in scripts_list: - with open(path, "r", encoding="utf8") as jsfile: - javascript += f"\n" - - if cmd_opts.theme is not None: - javascript += f"\n\n" - - javascript += f"\n" - - def template_response(*args, **kwargs): - res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace( - b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res - - gradio.routes.templates.TemplateResponse = template_response - - -if not hasattr(shared, 'GradioTemplateResponseOriginal'): - shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse - - -def versions_html(): - import torch - import launch - - python_version = ".".join([str(x) for x in sys.version_info[0:3]]) - commit = launch.commit_hash() - short_commit = commit[0:8] - - if shared.xformers_available: - import xformers - xformers_version = xformers.__version__ - else: - xformers_version = "N/A" - - return f""" -python: {python_version} - •  -torch: {torch.__version__} - •  -xformers: {xformers_version} - •  -gradio: {gr.__version__} - •  -commit: {short_commit} -""" -- cgit v1.2.3 From 54dd5d6634ead25311a8bea0484675607601a21a Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Tue, 10 Jan 2023 11:54:49 +0300 Subject: Split history ui.py to ui_progress.py --- modules/temp | 1928 --------------------------------------------------------- modules/ui.py | 1928 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1928 insertions(+), 1928 deletions(-) delete mode 100644 modules/temp create mode 100644 modules/ui.py (limited to 'modules') diff --git a/modules/temp b/modules/temp deleted file mode 100644 index 9b9081b5..00000000 --- a/modules/temp +++ /dev/null @@ -1,1928 +0,0 @@ -import html -import json -import math -import mimetypes -import os -import platform -import random -import subprocess as sp -import sys -import tempfile -import time -import traceback -from functools import partial, reduce - -import gradio as gr -import gradio.routes -import gradio.utils -import numpy as np -from PIL import Image, PngImagePlugin -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call - -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru -from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path - -from modules.shared import opts, cmd_opts, restricted_opts - -import modules.codeformer_model -import modules.generation_parameters_copypaste as parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.scripts -import modules.shared as shared -import modules.styles -import modules.textual_inversion.ui -from modules import prompt_parser -from modules.images import save_image -from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.textual_inversion import textual_inversion -import modules.hypernetworks.ui -from modules.generation_parameters_copypaste import image_from_url_text - -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI -mimetypes.init() -mimetypes.add_type('application/javascript', '.js') - -if not cmd_opts.share and not cmd_opts.listen: - # fix gradio phoning home - gradio.utils.version_check = lambda: None - gradio.utils.get_local_ip_address = lambda: '127.0.0.1' - -if cmd_opts.ngrok is not None: - import modules.ngrok as ngrok - print('ngrok authtoken detected, trying to connect...') - ngrok.connect( - cmd_opts.ngrok, - cmd_opts.port if cmd_opts.port is not None else 7860, - cmd_opts.ngrok_region - ) - - -def gr_show(visible=True): - return {"visible": visible, "__type__": "update"} - - -sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" -sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None - -css_hide_progressbar = """ -.wrap .m-12 svg { display:none!important; } -.wrap .m-12::before { content:"Loading..." } -.wrap .z-20 svg { display:none!important; } -.wrap .z-20::before { content:"Loading..." } -.progress-bar { display:none!important; } -.meta-text { display:none!important; } -.meta-text-center { display:none!important; } -""" - -# Using constants for these since the variation selector isn't visible. -# Important that they exactly match script.js for tooltip to work. -random_symbol = '\U0001f3b2\ufe0f' # 🎲️ -reuse_symbol = '\u267b\ufe0f' # ♻️ -paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 -refresh_symbol = '\U0001f504' # 🔄 -save_style_symbol = '\U0001f4be' # 💾 -apply_style_symbol = '\U0001f4cb' # 📋 -clear_prompt_symbol = '\U0001F5D1' # 🗑️ - - -def plaintext_to_html(text): - text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" - return text - -def send_gradio_gallery_to_image(x): - if len(x) == 0: - return None - return image_from_url_text(x[0]) - -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") - - -def calc_time_left(progress, threshold, label, force_display, show_eta): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = (time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and show_eta) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False), gr_show(False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - # Show progress percentage and time left at the same moment, and base it also on steps done - show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 - - time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" - - image = gr_show(False) - preview_visibility = gr_show(False) - - if opts.show_progress_every_n_steps != 0: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr_show(True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr_show(False) - - return f"

{progressbar}

", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - -def visit(x, func, path=""): - if hasattr(x, 'children'): - for c in x.children: - visit(c, func, path) - elif x.label is not None: - func(path + "/" + str(x.label), x) - - -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] - - -def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): - from modules import processing, devices - - if not enable: - return "" - - p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) - - with devices.autocast(): - p.init([""], [0], [0]) - - return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" - - -def apply_styles(prompt, prompt_neg, style1_name, style2_name): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] - - -def interrogate(image): - prompt = shared.interrogator.interrogate(image.convert("RGB")) - - return gr_show(True) if prompt is None else prompt - - -def interrogate_deepbooru(image): - prompt = deepbooru.model.tag(image) - return gr_show(True) if prompt is None else prompt - - -def create_seed_inputs(target_interface): - with FormRow(elem_id=target_interface + '_seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') - seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') - - with gr.Group(elem_id=target_interface + '_subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) - - # Components to show/hide based on the 'Extra' checkbox - seed_extras = [] - - with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: - seed_extras.append(seed_extra_row_1) - subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') - subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') - - with FormRow(visible=False) as seed_extra_row_2: - seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') - - random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) - random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) - - def change_visibility(show): - return {comp: gr_show(show) for comp in seed_extras} - - seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) - - return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox - - - -def connect_clear_prompt(button): - """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" - button.click( - _js="clear_prompt", - fn=None, - inputs=[], - outputs=[], - ) - - -def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): - """ Connects a 'reuse (sub)seed' button's click event so that it copies last used - (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength - was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" - def copy_seed(gen_info_string: str, index): - res = -1 - - try: - gen_info = json.loads(gen_info_string) - index -= gen_info.get('index_of_first_image', 0) - - if is_subseed and gen_info.get('subseed_strength', 0) > 0: - all_subseeds = gen_info.get('all_subseeds', [-1]) - res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] - else: - all_seeds = gen_info.get('all_seeds', [-1]) - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] - - except json.decoder.JSONDecodeError as e: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) - - return [res, gr_show(False)] - - reuse_seed.click( - fn=copy_seed, - _js="(x, y) => [x, selected_gallery_index()]", - show_progress=False, - inputs=[generation_info, dummy_component], - outputs=[seed, dummy_component] - ) - - -def update_token_counter(text, steps): - try: - _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) - prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) - - except Exception: - # a parsing error can happen here during typing, and we don't want to bother the user with - # messages related to it in console - prompt_schedules = [[[steps, text]]] - - flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) - prompts = [prompt_text for step, prompt_text in flat_prompts] - token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) - style_class = ' class="red"' if (token_count > max_length) else "" - return f"{token_count}/{max_length}" - - -def create_toprow(is_img2img): - id_part = "img2img" if is_img2img else "txt2img" - - with gr.Row(elem_id="toprow"): - with gr.Column(scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, - placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, - placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - - clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[prompt, negative_prompt], - outputs=[prompt, negative_prompt], - ) - - button_interrogate = None - button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_id="interrogate_col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1): - with gr.Row(): - skip = gr.Button('Skip', elem_id=f"{id_part}_skip") - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") - submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(): - with gr.Column(scale=1, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button - - -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) - - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - -def apply_setting(key, value): - if value is None: - return gr.update() - - if shared.cmd_opts.freeze_settings: - return gr.update() - - # dont allow model to be swapped when model hash exists in prompt - if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: - return gr.update() - - if key == "sd_model_checkpoint": - ckpt_info = sd_models.get_closet_checkpoint_match(value) - - if ckpt_info is not None: - value = ckpt_info.title - else: - return gr.update() - - comp_args = opts.data_labels[key].component_args - if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: - return - - valtype = type(opts.data_labels[key].default) - oldval = opts.data.get(key, None) - opts.data[key] = valtype(value) if valtype != type(None) else value - if oldval != value and opts.data_labels[key].onchange is not None: - opts.data_labels[key].onchange() - - opts.save(shared.config_filename) - return value - - -def update_generation_info(args): - generation_info, html_info, img_index = args - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info - - -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - -def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel'): - with gr.Group(): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') - - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') - - with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - save_zip.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log - - -def create_sampler_and_steps_selection(choices, tabname): - if opts.samplers_in_dropdown: - with FormRow(elem_id=f"sampler_selection_{tabname}"): - sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - else: - with FormGroup(elem_id=f"sampler_selection_{tabname}"): - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - - return steps, sampler_index - - -def ordered_ui_categories(): - user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} - - for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): - yield category - - -def create_ui(): - import modules.img2img - import modules.txt2img - - reload_javascript() - - parameters_copypaste.reset() - - modules.scripts.scripts_current = modules.scripts.scripts_txt2img - modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) - - with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) - - dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, txt2img_preview, 'txt2img') - - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel', elem_id="txt2img_settings"): - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="txt2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "cfg": - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - - elif category == "checkboxes": - with FormRow(elem_id="txt2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) - - elif category == "hires_fix": - with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: - with FormRow(elem_id="txt2img_hires_fix_row1"): - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - - with FormRow(elem_id="txt2img_hires_fix_row2"): - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") - hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - - hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - for input in hr_resolution_preview_inputs: - input.change( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False, - ) - input.change( - None, - _js="onCalcResolutionHires", - inputs=hr_resolution_preview_inputs, - outputs=[], - show_progress=False, - ) - - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) - parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - txt2img_args = dict( - fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), - _js="submit", - inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - denoising_strength, - hr_scale, - hr_upscaler, - hr_second_pass_steps, - hr_resize_x, - hr_resize_y, - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) - - txt_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - txt_prompt_img - ], - outputs=[ - txt2img_prompt, - txt_prompt_img - ] - ) - - enable_hr.change( - fn=lambda x: gr_show(x), - inputs=[enable_hr], - outputs=[hr_options], - show_progress = False, - ) - - txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (hr_scale, "Hires upscale"), - (hr_upscaler, "Hires upscaler"), - (hr_second_pass_steps, "Hires steps"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - *modules.scripts.scripts_txt2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) - - txt2img_preview_params = [ - txt2img_prompt, - txt2img_negative_prompt, - steps, - sampler_index, - cfg_scale, - seed, - width, - height, - ] - - token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) - - modules.scripts.scripts_current = modules.scripts.scripts_img2img - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) - - with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') - - with FormRow().style(equal_height=False): - with gr.Column(variant='panel', elem_id="img2img_settings"): - - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) - init_img_with_mask_orig = gr.State(None) - - use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" - if use_color_sketch: - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) - - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") - - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") - - with FormRow(): - mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): - hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "cfg": - with FormGroup(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') - - elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() - - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) - parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - img2img_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - img2img_prompt_img - ], - outputs=[ - img2img_prompt, - img2img_prompt_img - ] - ) - - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - - img2img_args = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), - _js="submit_img2img", - inputs=[ - dummy_component, - img2img_prompt, - img2img_negative_prompt, - img2img_prompt_style, - img2img_prompt_style2, - init_img, - init_img_with_mask, - init_img_with_mask_orig, - init_img_inpaint, - init_mask_inpaint, - mask_mode, - steps, - sampler_index, - mask_blur, - mask_alpha, - inpainting_fill, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - denoising_strength, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - resize_mode, - inpaint_full_res, - inpaint_full_res_padding, - inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, - ] + custom_inputs, - outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - img2img_prompt.submit(**img2img_args) - submit.click(**img2img_args) - - img2img_interrogate.click( - fn=interrogate, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - - for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): - button.click( - fn=add_style, - _js="ask_for_style_name", - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], - ) - - for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): - button.click( - fn=apply_styles, - _js=js_func, - inputs=[prompt, negative_prompt, style1, style2], - outputs=[prompt, negative_prompt, style1, style2], - ) - - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - - img2img_paste_fields = [ - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (mask_blur, "Mask blur"), - *modules.scripts.scripts_img2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) - - modules.scripts.scripts_current = None - - with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image', elem_id="extras_single_tab"): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - - with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - - with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") - show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") - - result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, - html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=modules.extras.clear_cache, - inputs=[], outputs=[] - ) - - with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") - - with gr.Column(variant='panel'): - html = gr.HTML() - generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") - html2 = gr.HTML() - with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) - - image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), - inputs=[image], - outputs=[html, generation_info, html2], - ) - - with gr.Blocks(analytics_enabled=False) as modelmerger_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - - with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") - - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") - create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") - - tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") - create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") - - custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") - - with gr.Row(): - checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") - save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - - with gr.Column(variant='panel'): - submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - - with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row().style(equal_height=False): - gr.HTML(value="

See wiki for detailed explanation.

") - - with gr.Row().style(equal_height=False): - with gr.Tabs(elem_id="train_tabs"): - - with gr.Tab(label="Create embedding"): - new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") - initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") - - with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") - new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") - - with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") - process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") - - with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") - process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") - process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") - process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") - process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") - run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - def get_textual_inversion_template_names(): - return sorted([x for x in textual_inversion.textual_inversion_templates]) - - with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") - with FormRow(): - train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") - - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) - create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - - with FormRow(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") - - with FormRow(): - clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) - clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) - - with FormRow(): - batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") - - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - - with FormRow(): - template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) - create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") - - training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") - training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") - varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") - steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") - - with FormRow(): - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") - - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") - - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") - - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") - - with gr.Row(): - train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") - interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") - - params = script_callbacks.UiTrainTabParams(txt2img_preview_params) - - script_callbacks.ui_train_tabs_callback(params) - - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") - ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) - ti_progress = gr.HTML(elem_id="ti_progress", value="") - ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) - - create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, - inputs=[ - new_embedding_name, - initialization_text, - nvpt, - overwrite_old_embedding, - ], - outputs=[ - train_embedding_name, - ti_output, - ti_outcome, - ] - ) - - create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, - inputs=[ - new_hypernetwork_name, - new_hypernetwork_sizes, - overwrite_old_hypernetwork, - new_hypernetwork_layer_structure, - new_hypernetwork_activation_func, - new_hypernetwork_initialization_option, - new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout, - new_hypernetwork_dropout_structure - ], - outputs=[ - train_hypernetwork_name, - ti_output, - ti_outcome, - ] - ) - - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - - train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_embedding_name, - embedding_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - save_image_with_stored_embedding, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_hypernetwork_name, - hypernetwork_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - interrupt_training.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {str(t)} for key {key}') - - elem_id = "setting_"+key - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - with FormRow(): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - - components = [] - component_dict = {} - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' - - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return gr.update(value=value), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - with gr.Row(): - with gr.Column(scale=6): - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - with gr.Column(): - restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") - - result = gr.HTML(elem_id="settings_result") - - quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] - quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} - - quicksettings_list = [] - - previous_section = None - current_tab = None - with gr.Tabs(elem_id="settings"): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - elem_id, text = item.section - - if current_tab is not None: - current_tab.__exit__() - - current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) - current_tab.__enter__() - - previous_section = item.section - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - - if current_tab is not None: - current_tab.__exit__() - - with gr.TabItem("Actions"): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") - - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - def request_restart(): - shared.state.interrupt() - shared.state.need_restart = True - - restart_gradio.click( - fn=request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) - - interfaces = [ - (txt2img_interface, "txt2img", "txt2img"), - (img2img_interface, "img2img", "img2img"), - (extras_interface, "Extras", "extras"), - (pnginfo_interface, "PNG Info", "pnginfo"), - (modelmerger_interface, "Checkpoint Merger", "modelmerger"), - (train_interface, "Train", "ti"), - ] - - css = "" - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - with open(cssfile, "r", encoding="utf8") as file: - css += file.read() + "\n" - - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - css += file.read() + "\n" - - if not cmd_opts.no_progressbar_hiding: - css += css_hide_progressbar - - interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] - - extensions_interface = ui_extensions.create_ui() - interfaces += [(extensions_interface, "Extensions", "extensions")] - - with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings"): - for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component - - parameters_copypaste.integrate_settings_paste_fields(component_dict) - parameters_copypaste.run_bind() - - with gr.Tabs(elem_id="tabs") as tabs: - for interface, label, ifid in interfaces: - with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): - interface.render() - - if os.path.exists(os.path.join(script_path, "notification.mp3")): - audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") - - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for i, k, item in quicksettings_list: - component = component_dict[k] - - component.change( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [getattr(opts, key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - ) - - def modelmerger(*args): - try: - results = modules.extras.run_modelmerger(*args) - except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() # to remove the potentially missing models from the list - return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] - return results - - modelmerger_merge.click( - fn=modelmerger, - inputs=[ - primary_model_name, - secondary_model_name, - tertiary_model_name, - interp_method, - interp_amount, - save_as_half, - custom_name, - checkpoint_format, - ], - outputs=[ - submit_result, - primary_model_name, - secondary_model_name, - tertiary_model_name, - component_dict['sd_model_checkpoint'], - ] - ) - - ui_config_file = cmd_opts.ui_config_file - ui_settings = {} - settings_count = len(ui_settings) - error_loading = False - - try: - if os.path.exists(ui_config_file): - with open(ui_config_file, "r", encoding="utf8") as file: - ui_settings = json.load(file) - except Exception: - error_loading = True - print("Error loading settings:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - def loadsave(path, x): - def apply_field(obj, field, condition=None, init_field=None): - key = path + "/" + field - - if getattr(obj, 'custom_script_source', None) is not None: - key = 'customscript/' + obj.custom_script_source + '/' + key - - if getattr(obj, 'do_not_save_to_config', False): - return - - saved_value = ui_settings.get(key, None) - if saved_value is None: - ui_settings[key] = getattr(obj, field) - elif condition and not condition(saved_value): - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') - else: - setattr(obj, field, saved_value) - if init_field is not None: - init_field(saved_value) - - if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: - apply_field(x, 'visible') - - if type(x) == gr.Slider: - apply_field(x, 'value') - apply_field(x, 'minimum') - apply_field(x, 'maximum') - apply_field(x, 'step') - - if type(x) == gr.Radio: - apply_field(x, 'value', lambda val: val in x.choices) - - if type(x) == gr.Checkbox: - apply_field(x, 'value') - - if type(x) == gr.Textbox: - apply_field(x, 'value') - - if type(x) == gr.Number: - apply_field(x, 'value') - - if type(x) == gr.Dropdown: - apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) - - visit(txt2img_interface, loadsave, "txt2img") - visit(img2img_interface, loadsave, "img2img") - visit(extras_interface, loadsave, "extras") - visit(modelmerger_interface, loadsave, "modelmerger") - visit(train_interface, loadsave, "train") - - if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): - with open(ui_config_file, "w", encoding="utf8") as file: - json.dump(ui_settings, file, indent=4) - - return demo - - -def reload_javascript(): - with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' - - scripts_list = modules.scripts.list_scripts("javascript", ".js") - - for basedir, filename, path in scripts_list: - with open(path, "r", encoding="utf8") as jsfile: - javascript += f"\n" - - if cmd_opts.theme is not None: - javascript += f"\n\n" - - javascript += f"\n" - - def template_response(*args, **kwargs): - res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace( - b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res - - gradio.routes.templates.TemplateResponse = template_response - - -if not hasattr(shared, 'GradioTemplateResponseOriginal'): - shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse - - -def versions_html(): - import torch - import launch - - python_version = ".".join([str(x) for x in sys.version_info[0:3]]) - commit = launch.commit_hash() - short_commit = commit[0:8] - - if shared.xformers_available: - import xformers - xformers_version = xformers.__version__ - else: - xformers_version = "N/A" - - return f""" -python: {python_version} - •  -torch: {torch.__version__} - •  -xformers: {xformers_version} - •  -gradio: {gr.__version__} - •  -commit: {short_commit} -""" diff --git a/modules/ui.py b/modules/ui.py new file mode 100644 index 00000000..9b9081b5 --- /dev/null +++ b/modules/ui.py @@ -0,0 +1,1928 @@ +import html +import json +import math +import mimetypes +import os +import platform +import random +import subprocess as sp +import sys +import tempfile +import time +import traceback +from functools import partial, reduce + +import gradio as gr +import gradio.routes +import gradio.utils +import numpy as np +from PIL import Image, PngImagePlugin +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call + +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru +from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML +from modules.paths import script_path + +from modules.shared import opts, cmd_opts, restricted_opts + +import modules.codeformer_model +import modules.generation_parameters_copypaste as parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.scripts +import modules.shared as shared +import modules.styles +import modules.textual_inversion.ui +from modules import prompt_parser +from modules.images import save_image +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img +from modules.textual_inversion import textual_inversion +import modules.hypernetworks.ui +from modules.generation_parameters_copypaste import image_from_url_text + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +if not cmd_opts.share and not cmd_opts.listen: + # fix gradio phoning home + gradio.utils.version_check = lambda: None + gradio.utils.get_local_ip_address = lambda: '127.0.0.1' + +if cmd_opts.ngrok is not None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect( + cmd_opts.ngrok, + cmd_opts.port if cmd_opts.port is not None else 7860, + cmd_opts.ngrok_region + ) + + +def gr_show(visible=True): + return {"visible": visible, "__type__": "update"} + + +sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" +sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None + +css_hide_progressbar = """ +.wrap .m-12 svg { display:none!important; } +.wrap .m-12::before { content:"Loading..." } +.wrap .z-20 svg { display:none!important; } +.wrap .z-20::before { content:"Loading..." } +.progress-bar { display:none!important; } +.meta-text { display:none!important; } +.meta-text-center { display:none!important; } +""" + +# Using constants for these since the variation selector isn't visible. +# Important that they exactly match script.js for tooltip to work. +random_symbol = '\U0001f3b2\ufe0f' # 🎲️ +reuse_symbol = '\u267b\ufe0f' # ♻️ +paste_symbol = '\u2199\ufe0f' # ↙ +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +apply_style_symbol = '\U0001f4cb' # 📋 +clear_prompt_symbol = '\U0001F5D1' # 🗑️ + + +def plaintext_to_html(text): + text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" + return text + +def send_gradio_gallery_to_image(x): + if len(x) == 0: + return None + return image_from_url_text(x[0]) + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = opts.outdir_save + save_to_dirs = opts.use_save_to_dirs_for_ui + extension: str = opts.samples_format + start_index = 0 + + if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(opts.outdir_save, exist_ok=True) + + with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") + + +def calc_time_left(progress, threshold, label, force_display, show_eta): + if progress == 0: + return "" + else: + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + eta_relative = eta-time_since_start + if (eta_relative > threshold and show_eta) or force_display: + if eta_relative > 3600: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + elif eta_relative > 60: + return label + time.strftime('%M:%S', time.gmtime(eta_relative)) + else: + return label + time.strftime('%Ss', time.gmtime(eta_relative)) + else: + return "" + + +def check_progress_call(id_part): + if shared.state.job_count == 0: + return "", gr_show(False), gr_show(False), gr_show(False) + + progress = 0 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + # Show progress percentage and time left at the same moment, and base it also on steps done + show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 + + time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) + if time_left != "": + shared.state.time_left_force_display = True + + progress = min(progress, 1) + + progressbar = "" + if opts.show_progressbar: + progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" + + image = gr_show(False) + preview_visibility = gr_show(False) + + if opts.show_progress_every_n_steps != 0: + shared.state.set_current_image() + image = shared.state.current_image + + if image is None: + image = gr.update(value=None) + else: + preview_visibility = gr_show(True) + + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"

{progressbar}

", preview_visibility, image, textinfo_result + + +def check_progress_call_initial(id_part): + shared.state.job_count = -1 + shared.state.current_latent = None + shared.state.current_image = None + shared.state.textinfo = None + shared.state.time_start = time.time() + shared.state.time_left_force_display = False + + return check_progress_call(id_part) + + +def visit(x, func, path=""): + if hasattr(x, 'children'): + for c in x.children: + visit(c, func, path) + elif x.label is not None: + func(path + "/" + str(x.label), x) + + +def add_style(name: str, prompt: str, negative_prompt: str): + if name is None: + return [gr_show() for x in range(4)] + + style = modules.styles.PromptStyle(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we + # reserialize all styles every time we save them + shared.prompt_styles.save_styles(shared.styles_filename) + + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] + + +def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): + from modules import processing, devices + + if not enable: + return "" + + p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) + + with devices.autocast(): + p.init([""], [0], [0]) + + return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" + + +def apply_styles(prompt, prompt_neg, style1_name, style2_name): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) + + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] + + +def interrogate(image): + prompt = shared.interrogator.interrogate(image.convert("RGB")) + + return gr_show(True) if prompt is None else prompt + + +def interrogate_deepbooru(image): + prompt = deepbooru.model.tag(image) + return gr_show(True) if prompt is None else prompt + + +def create_seed_inputs(target_interface): + with FormRow(elem_id=target_interface + '_seed_row'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') + seed.style(container=False) + random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') + reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') + + with gr.Group(elem_id=target_interface + '_subseed_show_box'): + seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) + + # Components to show/hide based on the 'Extra' checkbox + seed_extras = [] + + with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: + seed_extras.append(seed_extra_row_1) + subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') + subseed.style(container=False) + random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') + reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') + + with FormRow(visible=False) as seed_extra_row_2: + seed_extras.append(seed_extra_row_2) + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') + + random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) + random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) + + def change_visibility(show): + return {comp: gr_show(show) for comp in seed_extras} + + seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) + + return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox + + + +def connect_clear_prompt(button): + """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" + button.click( + _js="clear_prompt", + fn=None, + inputs=[], + outputs=[], + ) + + +def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): + """ Connects a 'reuse (sub)seed' button's click event so that it copies last used + (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength + was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" + def copy_seed(gen_info_string: str, index): + res = -1 + + try: + gen_info = json.loads(gen_info_string) + index -= gen_info.get('index_of_first_image', 0) + + if is_subseed and gen_info.get('subseed_strength', 0) > 0: + all_subseeds = gen_info.get('all_subseeds', [-1]) + res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] + else: + all_seeds = gen_info.get('all_seeds', [-1]) + res = all_seeds[index if 0 <= index < len(all_seeds) else 0] + + except json.decoder.JSONDecodeError as e: + if gen_info_string != '': + print("Error parsing JSON generation info:", file=sys.stderr) + print(gen_info_string, file=sys.stderr) + + return [res, gr_show(False)] + + reuse_seed.click( + fn=copy_seed, + _js="(x, y) => [x, selected_gallery_index()]", + show_progress=False, + inputs=[generation_info, dummy_component], + outputs=[seed, dummy_component] + ) + + +def update_token_counter(text, steps): + try: + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) + + except Exception: + # a parsing error can happen here during typing, and we don't want to bother the user with + # messages related to it in console + prompt_schedules = [[[steps, text]]] + + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) + prompts = [prompt_text for step, prompt_text in flat_prompts] + token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) + style_class = ' class="red"' if (token_count > max_length) else "" + return f"{token_count}/{max_length}" + + +def create_toprow(is_img2img): + id_part = "img2img" if is_img2img else "txt2img" + + with gr.Row(elem_id="toprow"): + with gr.Column(scale=6): + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, + placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, + placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Column(scale=1, elem_id="roll_col"): + paste = gr.Button(value=paste_symbol, elem_id="paste") + save_style = gr.Button(value=save_style_symbol, elem_id="style_create") + prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") + clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + + clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[prompt, negative_prompt], + outputs=[prompt, negative_prompt], + ) + + button_interrogate = None + button_deepbooru = None + if is_img2img: + with gr.Column(scale=1, elem_id="interrogate_col"): + button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + + with gr.Column(scale=1): + with gr.Row(): + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") + interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") + submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + + interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + with gr.Row(): + with gr.Column(scale=1, elem_id="style_pos_col"): + prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + + with gr.Column(scale=1, elem_id="style_neg_col"): + prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + + return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + + +def setup_progressbar(progressbar, preview, id_part, textinfo=None): + if textinfo is None: + textinfo = gr.HTML(visible=False) + + check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) + check_progress.click( + fn=lambda: check_progress_call(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) + check_progress_initial.click( + fn=lambda: check_progress_call_initial(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + +def apply_setting(key, value): + if value is None: + return gr.update() + + if shared.cmd_opts.freeze_settings: + return gr.update() + + # dont allow model to be swapped when model hash exists in prompt + if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: + return gr.update() + + if key == "sd_model_checkpoint": + ckpt_info = sd_models.get_closet_checkpoint_match(value) + + if ckpt_info is not None: + value = ckpt_info.title + else: + return gr.update() + + comp_args = opts.data_labels[key].component_args + if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: + return + + valtype = type(opts.data_labels[key].default) + oldval = opts.data.get(key, None) + opts.data[key] = valtype(value) if valtype != type(None) else value + if oldval != value and opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + return value + + +def update_generation_info(args): + generation_info, html_info, img_index = args + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info + return plaintext_to_html(generation_info["infotexts"][img_index]) + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + +def create_output_panel(tabname, outdir): + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel'): + with gr.Group(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(elem_id=f"image_buttons_{tabname}"): + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') + + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + + open_folder_button.click( + fn=lambda: open_folder(opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') + + with gr.Group(): + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="(x, y) => [x, y, selected_gallery_index()]", + inputs=[generation_info, html_info], + outputs=[html_info], + preprocess=False + ) + + save.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + save_zip.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + else: + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + + +def create_sampler_and_steps_selection(choices, tabname): + if opts.samplers_in_dropdown: + with FormRow(elem_id=f"sampler_selection_{tabname}"): + sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + else: + with FormGroup(elem_id=f"sampler_selection_{tabname}"): + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + + return steps, sampler_index + + +def ordered_ui_categories(): + user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} + + for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): + yield category + + +def create_ui(): + import modules.img2img + import modules.txt2img + + reload_javascript() + + parameters_copypaste.reset() + + modules.scripts.scripts_current = modules.scripts.scripts_txt2img + modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + + with gr.Blocks(analytics_enabled=False) as txt2img_interface: + txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + + dummy_component = gr.Label(visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Row(elem_id='txt2img_progress_row'): + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="txt2img_progressbar") + txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) + setup_progressbar(progressbar, txt2img_preview, 'txt2img') + + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel', elem_id="txt2img_settings"): + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="txt2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "cfg": + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') + + elif category == "checkboxes": + with FormRow(elem_id="txt2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") + enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) + + elif category == "hires_fix": + with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: + with FormRow(elem_id="txt2img_hires_fix_row1"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + with FormRow(elem_id="txt2img_hires_fix_row2"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="txt2img_script_container"): + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] + for input in hr_resolution_preview_inputs: + input.change( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + input.change( + None, + _js="onCalcResolutionHires", + inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) + + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) + parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + txt2img_args = dict( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), + _js="submit", + inputs=[ + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_style, + txt2img_prompt_style2, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + denoising_strength, + hr_scale, + hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, + ] + custom_inputs, + + outputs=[ + txt2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + txt2img_prompt.submit(**txt2img_args) + submit.click(**txt2img_args) + + txt_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + txt_prompt_img + ], + outputs=[ + txt2img_prompt, + txt_prompt_img + ] + ) + + enable_hr.change( + fn=lambda x: gr_show(x), + inputs=[enable_hr], + outputs=[hr_options], + show_progress = False, + ) + + txt2img_paste_fields = [ + (txt2img_prompt, "Prompt"), + (txt2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (enable_hr, lambda d: "Denoising strength" in d), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), + (hr_second_pass_steps, "Hires steps"), + (hr_resize_x, "Hires resize-1"), + (hr_resize_y, "Hires resize-2"), + *modules.scripts.scripts_txt2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) + + txt2img_preview_params = [ + txt2img_prompt, + txt2img_negative_prompt, + steps, + sampler_index, + cfg_scale, + seed, + width, + height, + ] + + token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) + + modules.scripts.scripts_current = modules.scripts.scripts_img2img + modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + + with gr.Blocks(analytics_enabled=False) as img2img_interface: + img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) + + with gr.Row(elem_id='img2img_progress_row'): + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="img2img_progressbar") + img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) + setup_progressbar(progressbar, img2img_preview, 'img2img') + + with FormRow().style(equal_height=False): + with gr.Column(variant='panel', elem_id="img2img_settings"): + + with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) + init_img_with_mask_orig = gr.State(None) + + use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" + if use_color_sketch: + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) + + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") + + with FormRow(): + mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): + hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "cfg": + with FormGroup(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') + + elif category == "checkboxes": + with FormRow(elem_id="img2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="img2img_script_container"): + custom_inputs = modules.scripts.scripts_img2img.setup_ui() + + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) + parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + img2img_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + img2img_prompt_img + ], + outputs=[ + img2img_prompt, + img2img_prompt_img + ] + ) + + mask_mode.change( + lambda mode, img: { + init_img_with_mask: gr_show(mode == 0), + init_img_inpaint: gr_show(mode == 1), + init_mask_inpaint: gr_show(mode == 1), + }, + inputs=[mask_mode, init_img_with_mask], + outputs=[ + init_img_with_mask, + init_img_inpaint, + init_mask_inpaint, + ], + ) + + img2img_args = dict( + fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), + _js="submit_img2img", + inputs=[ + dummy_component, + img2img_prompt, + img2img_negative_prompt, + img2img_prompt_style, + img2img_prompt_style2, + init_img, + init_img_with_mask, + init_img_with_mask_orig, + init_img_inpaint, + init_mask_inpaint, + mask_mode, + steps, + sampler_index, + mask_blur, + mask_alpha, + inpainting_fill, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + denoising_strength, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + resize_mode, + inpaint_full_res, + inpaint_full_res_padding, + inpainting_mask_invert, + img2img_batch_input_dir, + img2img_batch_output_dir, + ] + custom_inputs, + outputs=[ + img2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + img2img_prompt.submit(**img2img_args) + submit.click(**img2img_args) + + img2img_interrogate.click( + fn=interrogate, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] + style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] + + for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): + button.click( + fn=add_style, + _js="ask_for_style_name", + # Have to pass empty dummy component here, because the JavaScript and Python function have to accept + # the same number of parameters, but we only know the style-name after the JavaScript prompt + inputs=[dummy_component, prompt, negative_prompt], + outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], + ) + + for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): + button.click( + fn=apply_styles, + _js=js_func, + inputs=[prompt, negative_prompt, style1, style2], + outputs=[prompt, negative_prompt, style1, style2], + ) + + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + + img2img_paste_fields = [ + (img2img_prompt, "Prompt"), + (img2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (mask_blur, "Mask blur"), + *modules.scripts.scripts_img2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) + parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) + + modules.scripts.scripts_current = None + + with gr.Blocks(analytics_enabled=False) as extras_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image', elem_id="extras_single_tab"): + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + + with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") + + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): + with gr.Group(): + with gr.Row(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with gr.Group(): + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + + with gr.Group(): + extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") + + with gr.Group(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") + + with gr.Group(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") + + with gr.Group(): + upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") + + result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) + + submit.click( + fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), + _js="get_extras_tab_index", + inputs=[ + dummy_component, + dummy_component, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + gfpgan_visibility, + codeformer_visibility, + codeformer_weight, + upscaling_resize, + upscaling_resize_w, + upscaling_resize_h, + upscaling_crop, + extras_upscaler_1, + extras_upscaler_2, + extras_upscaler_2_visibility, + upscale_before_face_fix, + ], + outputs=[ + result_images, + html_info_x, + html_info, + ] + ) + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=modules.extras.clear_cache, + inputs=[], outputs=[] + ) + + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") + + with gr.Column(variant='panel'): + html = gr.HTML() + generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") + html2 = gr.HTML() + with gr.Row(): + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) + parameters_copypaste.bind_buttons(buttons, image, generation_info) + + image.change( + fn=wrap_gradio_call(modules.extras.run_pnginfo), + inputs=[image], + outputs=[html, generation_info, html2], + ) + + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") + + with gr.Row(): + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") + + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") + + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") + + custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + + with gr.Row(): + checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") + save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") + + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + + with gr.Column(variant='panel'): + submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + + with gr.Blocks(analytics_enabled=False) as train_interface: + with gr.Row().style(equal_height=False): + gr.HTML(value="

See wiki for detailed explanation.

") + + with gr.Row().style(equal_height=False): + with gr.Tabs(elem_id="train_tabs"): + + with gr.Tab(label="Create embedding"): + new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") + initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") + + with gr.Tab(label="Create hypernetwork"): + new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") + + with gr.Tab(label="Preprocess images"): + process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") + process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") + process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") + process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") + + with gr.Row(): + process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") + process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") + process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") + process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") + + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") + + with gr.Row(visible=False) as process_focal_crop_row: + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") + process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") + run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") + + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + + process_focal_crop.change( + fn=lambda show: gr_show(show), + inputs=[process_focal_crop], + outputs=[process_focal_crop_row], + ) + + def get_textual_inversion_template_names(): + return sorted([x for x in textual_inversion.textual_inversion_templates]) + + with gr.Tab(label="Train"): + gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") + with FormRow(): + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) + create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") + + with FormRow(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") + + with FormRow(): + clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) + clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) + + with FormRow(): + batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") + + with FormRow(): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") + training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") + steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") + + with FormRow(): + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") + + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") + + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") + + with gr.Row(): + train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") + interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") + + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + initialization_text, + nvpt, + overwrite_old_embedding, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + create_hypernetwork.click( + fn=modules.hypernetworks.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + new_hypernetwork_sizes, + overwrite_old_hypernetwork, + new_hypernetwork_layer_structure, + new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout, + new_hypernetwork_dropout_structure + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + process_src, + process_dst, + process_width, + process_height, + preprocess_txt_action, + process_flip, + process_split, + process_caption, + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, + process_focal_crop, + process_focal_crop_face_weight, + process_focal_crop_entropy_weight, + process_focal_crop_edges_weight, + process_focal_crop_debug, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + embedding_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + save_image_with_stored_embedding, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_hypernetwork_name, + hypernetwork_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + interrupt_preprocessing.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {str(t)} for key {key}') + + elem_id = "setting_"+key + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + with FormRow(): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + components = [] + component_dict = {} + + script_callbacks.ui_settings_callback() + opts.reorder() + + def run_settings(*args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' + + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return gr.update(value=value), opts.dumpjson() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + with gr.Row(): + with gr.Column(scale=6): + settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") + + result = gr.HTML(elem_id="settings_result") + + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} + + quicksettings_list = [] + + previous_section = None + current_tab = None + with gr.Tabs(elem_id="settings"): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + elem_id, text = item.section + + if current_tab is not None: + current_tab.__exit__() + + current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) + current_tab.__enter__() + + previous_section = item.section + + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + + if current_tab is not None: + current_tab.__exit__() + + with gr.TabItem("Actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + + if os.path.exists("html/licenses.html"): + with open("html/licenses.html", encoding="utf8") as file: + with gr.TabItem("Licenses"): + gr.HTML(file.read(), elem_id="licenses") + + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + def request_restart(): + shared.state.interrupt() + shared.state.need_restart = True + + restart_gradio.click( + fn=request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + interfaces = [ + (txt2img_interface, "txt2img", "txt2img"), + (img2img_interface, "img2img", "img2img"), + (extras_interface, "Extras", "extras"), + (pnginfo_interface, "PNG Info", "pnginfo"), + (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (train_interface, "Train", "ti"), + ] + + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + with open(cssfile, "r", encoding="utf8") as file: + css += file.read() + "\n" + + if os.path.exists(os.path.join(script_path, "user.css")): + with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + css += file.read() + "\n" + + if not cmd_opts.no_progressbar_hiding: + css += css_hide_progressbar + + interfaces += script_callbacks.ui_tabs_callback() + interfaces += [(settings_interface, "Settings", "settings")] + + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: + with gr.Row(elem_id="quicksettings"): + for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): + component = create_setting_component(k, is_quicksettings=True) + component_dict[k] = component + + parameters_copypaste.integrate_settings_paste_fields(component_dict) + parameters_copypaste.run_bind() + + with gr.Tabs(elem_id="tabs") as tabs: + for interface, label, ifid in interfaces: + with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): + interface.render() + + if os.path.exists(os.path.join(script_path, "notification.mp3")): + audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) + + if os.path.exists("html/footer.html"): + with open("html/footer.html", encoding="utf8") as file: + footer = file.read() + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") + + text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + settings_submit.click( + fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), + inputs=components, + outputs=[text_settings, result], + ) + + for i, k, item in quicksettings_list: + component = component_dict[k] + + component.change( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + + def get_settings_values(): + return [getattr(opts, key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[component_dict[k] for k in component_keys], + ) + + def modelmerger(*args): + try: + results = modules.extras.run_modelmerger(*args) + except Exception as e: + print("Error loading/saving model file:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + modules.sd_models.list_models() # to remove the potentially missing models from the list + return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] + return results + + modelmerger_merge.click( + fn=modelmerger, + inputs=[ + primary_model_name, + secondary_model_name, + tertiary_model_name, + interp_method, + interp_amount, + save_as_half, + custom_name, + checkpoint_format, + ], + outputs=[ + submit_result, + primary_model_name, + secondary_model_name, + tertiary_model_name, + component_dict['sd_model_checkpoint'], + ] + ) + + ui_config_file = cmd_opts.ui_config_file + ui_settings = {} + settings_count = len(ui_settings) + error_loading = False + + try: + if os.path.exists(ui_config_file): + with open(ui_config_file, "r", encoding="utf8") as file: + ui_settings = json.load(file) + except Exception: + error_loading = True + print("Error loading settings:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + def loadsave(path, x): + def apply_field(obj, field, condition=None, init_field=None): + key = path + "/" + field + + if getattr(obj, 'custom_script_source', None) is not None: + key = 'customscript/' + obj.custom_script_source + '/' + key + + if getattr(obj, 'do_not_save_to_config', False): + return + + saved_value = ui_settings.get(key, None) + if saved_value is None: + ui_settings[key] = getattr(obj, field) + elif condition and not condition(saved_value): + print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + else: + setattr(obj, field, saved_value) + if init_field is not None: + init_field(saved_value) + + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: + apply_field(x, 'visible') + + if type(x) == gr.Slider: + apply_field(x, 'value') + apply_field(x, 'minimum') + apply_field(x, 'maximum') + apply_field(x, 'step') + + if type(x) == gr.Radio: + apply_field(x, 'value', lambda val: val in x.choices) + + if type(x) == gr.Checkbox: + apply_field(x, 'value') + + if type(x) == gr.Textbox: + apply_field(x, 'value') + + if type(x) == gr.Number: + apply_field(x, 'value') + + if type(x) == gr.Dropdown: + apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) + + visit(txt2img_interface, loadsave, "txt2img") + visit(img2img_interface, loadsave, "img2img") + visit(extras_interface, loadsave, "extras") + visit(modelmerger_interface, loadsave, "modelmerger") + visit(train_interface, loadsave, "train") + + if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): + with open(ui_config_file, "w", encoding="utf8") as file: + json.dump(ui_settings, file, indent=4) + + return demo + + +def reload_javascript(): + with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: + javascript = f'' + + scripts_list = modules.scripts.list_scripts("javascript", ".js") + + for basedir, filename, path in scripts_list: + with open(path, "r", encoding="utf8") as jsfile: + javascript += f"\n" + + if cmd_opts.theme is not None: + javascript += f"\n\n" + + javascript += f"\n" + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace( + b'', f'{javascript}'.encode("utf8")) + res.init_headers() + return res + + gradio.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse + + +def versions_html(): + import torch + import launch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = launch.commit_hash() + short_commit = commit[0:8] + + if shared.xformers_available: + import xformers + xformers_version = xformers.__version__ + else: + xformers_version = "N/A" + + return f""" +python: {python_version} + •  +torch: {torch.__version__} + •  +xformers: {xformers_version} + •  +gradio: {gr.__version__} + •  +commit: {short_commit} +""" -- cgit v1.2.3 From ef75c980536471c0729a2319440e3083cd57a4f0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 12:29:45 +0300 Subject: Split history ui.py to ui_progress.py --- modules/ui.py | 94 +-- modules/ui_progress.py | 1839 +----------------------------------------------- 2 files changed, 9 insertions(+), 1924 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 9b9081b5..3c458ce8 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -162,79 +162,6 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") -def calc_time_left(progress, threshold, label, force_display, show_eta): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = (time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and show_eta) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False), gr_show(False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - # Show progress percentage and time left at the same moment, and base it also on steps done - show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 - - time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" - - image = gr_show(False) - preview_visibility = gr_show(False) - - if opts.show_progress_every_n_steps != 0: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr_show(True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr_show(False) - - return f"

{progressbar}

", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - def visit(x, func, path=""): if hasattr(x, 'children'): for c in x.children: @@ -456,25 +383,10 @@ def create_toprow(is_img2img): return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) +def setup_progressbar(*args, **kwargs): + import modules.ui_progress - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) + modules.ui_progress.setup_progressbar(*args, **kwargs) def apply_setting(key, value): diff --git a/modules/ui_progress.py b/modules/ui_progress.py index 9b9081b5..592fda55 100644 --- a/modules/ui_progress.py +++ b/modules/ui_progress.py @@ -1,165 +1,10 @@ -import html -import json -import math -import mimetypes -import os -import platform -import random -import subprocess as sp -import sys -import tempfile import time -import traceback -from functools import partial, reduce import gradio as gr -import gradio.routes -import gradio.utils -import numpy as np -from PIL import Image, PngImagePlugin -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru -from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path +from modules.shared import opts -from modules.shared import opts, cmd_opts, restricted_opts - -import modules.codeformer_model -import modules.generation_parameters_copypaste as parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.scripts import modules.shared as shared -import modules.styles -import modules.textual_inversion.ui -from modules import prompt_parser -from modules.images import save_image -from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.textual_inversion import textual_inversion -import modules.hypernetworks.ui -from modules.generation_parameters_copypaste import image_from_url_text - -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI -mimetypes.init() -mimetypes.add_type('application/javascript', '.js') - -if not cmd_opts.share and not cmd_opts.listen: - # fix gradio phoning home - gradio.utils.version_check = lambda: None - gradio.utils.get_local_ip_address = lambda: '127.0.0.1' - -if cmd_opts.ngrok is not None: - import modules.ngrok as ngrok - print('ngrok authtoken detected, trying to connect...') - ngrok.connect( - cmd_opts.ngrok, - cmd_opts.port if cmd_opts.port is not None else 7860, - cmd_opts.ngrok_region - ) - - -def gr_show(visible=True): - return {"visible": visible, "__type__": "update"} - - -sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" -sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None - -css_hide_progressbar = """ -.wrap .m-12 svg { display:none!important; } -.wrap .m-12::before { content:"Loading..." } -.wrap .z-20 svg { display:none!important; } -.wrap .z-20::before { content:"Loading..." } -.progress-bar { display:none!important; } -.meta-text { display:none!important; } -.meta-text-center { display:none!important; } -""" - -# Using constants for these since the variation selector isn't visible. -# Important that they exactly match script.js for tooltip to work. -random_symbol = '\U0001f3b2\ufe0f' # 🎲️ -reuse_symbol = '\u267b\ufe0f' # ♻️ -paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 -refresh_symbol = '\U0001f504' # 🔄 -save_style_symbol = '\U0001f4be' # 💾 -apply_style_symbol = '\U0001f4cb' # 📋 -clear_prompt_symbol = '\U0001F5D1' # 🗑️ - - -def plaintext_to_html(text): - text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" - return text - -def send_gradio_gallery_to_image(x): - if len(x) == 0: - return None - return image_from_url_text(x[0]) - -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") def calc_time_left(progress, threshold, label, force_display, show_eta): @@ -182,7 +27,7 @@ def calc_time_left(progress, threshold, label, force_display, show_eta): def check_progress_call(id_part): if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False), gr_show(False) + return "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) progress = 0 @@ -204,8 +49,8 @@ def check_progress_call(id_part): if opts.show_progressbar: progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" - image = gr_show(False) - preview_visibility = gr_show(False) + image = gr.update(visible=False) + preview_visibility = gr.update(visible=False) if opts.show_progress_every_n_steps != 0: shared.state.set_current_image() @@ -214,12 +59,12 @@ def check_progress_call(id_part): if image is None: image = gr.update(value=None) else: - preview_visibility = gr_show(True) + preview_visibility = gr.update(visible=True) if shared.state.textinfo is not None: textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) else: - textinfo_result = gr_show(False) + textinfo_result = gr.update(visible=False) return f"

{progressbar}

", preview_visibility, image, textinfo_result @@ -235,227 +80,6 @@ def check_progress_call_initial(id_part): return check_progress_call(id_part) -def visit(x, func, path=""): - if hasattr(x, 'children'): - for c in x.children: - visit(c, func, path) - elif x.label is not None: - func(path + "/" + str(x.label), x) - - -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] - - -def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): - from modules import processing, devices - - if not enable: - return "" - - p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) - - with devices.autocast(): - p.init([""], [0], [0]) - - return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" - - -def apply_styles(prompt, prompt_neg, style1_name, style2_name): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] - - -def interrogate(image): - prompt = shared.interrogator.interrogate(image.convert("RGB")) - - return gr_show(True) if prompt is None else prompt - - -def interrogate_deepbooru(image): - prompt = deepbooru.model.tag(image) - return gr_show(True) if prompt is None else prompt - - -def create_seed_inputs(target_interface): - with FormRow(elem_id=target_interface + '_seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') - seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') - - with gr.Group(elem_id=target_interface + '_subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) - - # Components to show/hide based on the 'Extra' checkbox - seed_extras = [] - - with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: - seed_extras.append(seed_extra_row_1) - subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') - subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') - - with FormRow(visible=False) as seed_extra_row_2: - seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') - - random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) - random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) - - def change_visibility(show): - return {comp: gr_show(show) for comp in seed_extras} - - seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) - - return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox - - - -def connect_clear_prompt(button): - """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" - button.click( - _js="clear_prompt", - fn=None, - inputs=[], - outputs=[], - ) - - -def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): - """ Connects a 'reuse (sub)seed' button's click event so that it copies last used - (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength - was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" - def copy_seed(gen_info_string: str, index): - res = -1 - - try: - gen_info = json.loads(gen_info_string) - index -= gen_info.get('index_of_first_image', 0) - - if is_subseed and gen_info.get('subseed_strength', 0) > 0: - all_subseeds = gen_info.get('all_subseeds', [-1]) - res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] - else: - all_seeds = gen_info.get('all_seeds', [-1]) - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] - - except json.decoder.JSONDecodeError as e: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) - - return [res, gr_show(False)] - - reuse_seed.click( - fn=copy_seed, - _js="(x, y) => [x, selected_gallery_index()]", - show_progress=False, - inputs=[generation_info, dummy_component], - outputs=[seed, dummy_component] - ) - - -def update_token_counter(text, steps): - try: - _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) - prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) - - except Exception: - # a parsing error can happen here during typing, and we don't want to bother the user with - # messages related to it in console - prompt_schedules = [[[steps, text]]] - - flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) - prompts = [prompt_text for step, prompt_text in flat_prompts] - token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) - style_class = ' class="red"' if (token_count > max_length) else "" - return f"{token_count}/{max_length}" - - -def create_toprow(is_img2img): - id_part = "img2img" if is_img2img else "txt2img" - - with gr.Row(elem_id="toprow"): - with gr.Column(scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, - placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, - placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - - clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[prompt, negative_prompt], - outputs=[prompt, negative_prompt], - ) - - button_interrogate = None - button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_id="interrogate_col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1): - with gr.Row(): - skip = gr.Button('Skip', elem_id=f"{id_part}_skip") - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") - submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(): - with gr.Column(scale=1, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button - - def setup_progressbar(progressbar, preview, id_part, textinfo=None): if textinfo is None: textinfo = gr.HTML(visible=False) @@ -475,1454 +99,3 @@ def setup_progressbar(progressbar, preview, id_part, textinfo=None): inputs=[], outputs=[progressbar, preview, preview, textinfo], ) - - -def apply_setting(key, value): - if value is None: - return gr.update() - - if shared.cmd_opts.freeze_settings: - return gr.update() - - # dont allow model to be swapped when model hash exists in prompt - if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: - return gr.update() - - if key == "sd_model_checkpoint": - ckpt_info = sd_models.get_closet_checkpoint_match(value) - - if ckpt_info is not None: - value = ckpt_info.title - else: - return gr.update() - - comp_args = opts.data_labels[key].component_args - if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: - return - - valtype = type(opts.data_labels[key].default) - oldval = opts.data.get(key, None) - opts.data[key] = valtype(value) if valtype != type(None) else value - if oldval != value and opts.data_labels[key].onchange is not None: - opts.data_labels[key].onchange() - - opts.save(shared.config_filename) - return value - - -def update_generation_info(args): - generation_info, html_info, img_index = args - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info - - -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - -def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel'): - with gr.Group(): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') - - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') - - with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - save_zip.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log - - -def create_sampler_and_steps_selection(choices, tabname): - if opts.samplers_in_dropdown: - with FormRow(elem_id=f"sampler_selection_{tabname}"): - sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - else: - with FormGroup(elem_id=f"sampler_selection_{tabname}"): - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - - return steps, sampler_index - - -def ordered_ui_categories(): - user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} - - for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): - yield category - - -def create_ui(): - import modules.img2img - import modules.txt2img - - reload_javascript() - - parameters_copypaste.reset() - - modules.scripts.scripts_current = modules.scripts.scripts_txt2img - modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) - - with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) - - dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, txt2img_preview, 'txt2img') - - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel', elem_id="txt2img_settings"): - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="txt2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "cfg": - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - - elif category == "checkboxes": - with FormRow(elem_id="txt2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) - - elif category == "hires_fix": - with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: - with FormRow(elem_id="txt2img_hires_fix_row1"): - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - - with FormRow(elem_id="txt2img_hires_fix_row2"): - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") - hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - - hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - for input in hr_resolution_preview_inputs: - input.change( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False, - ) - input.change( - None, - _js="onCalcResolutionHires", - inputs=hr_resolution_preview_inputs, - outputs=[], - show_progress=False, - ) - - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) - parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - txt2img_args = dict( - fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), - _js="submit", - inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - denoising_strength, - hr_scale, - hr_upscaler, - hr_second_pass_steps, - hr_resize_x, - hr_resize_y, - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) - - txt_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - txt_prompt_img - ], - outputs=[ - txt2img_prompt, - txt_prompt_img - ] - ) - - enable_hr.change( - fn=lambda x: gr_show(x), - inputs=[enable_hr], - outputs=[hr_options], - show_progress = False, - ) - - txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (hr_scale, "Hires upscale"), - (hr_upscaler, "Hires upscaler"), - (hr_second_pass_steps, "Hires steps"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - *modules.scripts.scripts_txt2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) - - txt2img_preview_params = [ - txt2img_prompt, - txt2img_negative_prompt, - steps, - sampler_index, - cfg_scale, - seed, - width, - height, - ] - - token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) - - modules.scripts.scripts_current = modules.scripts.scripts_img2img - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) - - with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') - - with FormRow().style(equal_height=False): - with gr.Column(variant='panel', elem_id="img2img_settings"): - - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) - init_img_with_mask_orig = gr.State(None) - - use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" - if use_color_sketch: - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) - - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") - - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") - - with FormRow(): - mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): - hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "cfg": - with FormGroup(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') - - elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() - - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) - parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - img2img_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - img2img_prompt_img - ], - outputs=[ - img2img_prompt, - img2img_prompt_img - ] - ) - - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - - img2img_args = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), - _js="submit_img2img", - inputs=[ - dummy_component, - img2img_prompt, - img2img_negative_prompt, - img2img_prompt_style, - img2img_prompt_style2, - init_img, - init_img_with_mask, - init_img_with_mask_orig, - init_img_inpaint, - init_mask_inpaint, - mask_mode, - steps, - sampler_index, - mask_blur, - mask_alpha, - inpainting_fill, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - denoising_strength, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - resize_mode, - inpaint_full_res, - inpaint_full_res_padding, - inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, - ] + custom_inputs, - outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - img2img_prompt.submit(**img2img_args) - submit.click(**img2img_args) - - img2img_interrogate.click( - fn=interrogate, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - - for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): - button.click( - fn=add_style, - _js="ask_for_style_name", - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], - ) - - for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): - button.click( - fn=apply_styles, - _js=js_func, - inputs=[prompt, negative_prompt, style1, style2], - outputs=[prompt, negative_prompt, style1, style2], - ) - - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - - img2img_paste_fields = [ - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (mask_blur, "Mask blur"), - *modules.scripts.scripts_img2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) - - modules.scripts.scripts_current = None - - with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image', elem_id="extras_single_tab"): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - - with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - - with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") - show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") - - result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, - html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=modules.extras.clear_cache, - inputs=[], outputs=[] - ) - - with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") - - with gr.Column(variant='panel'): - html = gr.HTML() - generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") - html2 = gr.HTML() - with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) - - image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), - inputs=[image], - outputs=[html, generation_info, html2], - ) - - with gr.Blocks(analytics_enabled=False) as modelmerger_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - - with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") - - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") - create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") - - tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") - create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") - - custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") - - with gr.Row(): - checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") - save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - - with gr.Column(variant='panel'): - submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - - with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row().style(equal_height=False): - gr.HTML(value="

See wiki for detailed explanation.

") - - with gr.Row().style(equal_height=False): - with gr.Tabs(elem_id="train_tabs"): - - with gr.Tab(label="Create embedding"): - new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") - initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") - - with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") - new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") - - with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") - process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") - - with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") - process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") - process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") - process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") - process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") - run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - def get_textual_inversion_template_names(): - return sorted([x for x in textual_inversion.textual_inversion_templates]) - - with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") - with FormRow(): - train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") - - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) - create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - - with FormRow(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") - - with FormRow(): - clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) - clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) - - with FormRow(): - batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") - - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - - with FormRow(): - template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) - create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") - - training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") - training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") - varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") - steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") - - with FormRow(): - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") - - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") - - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") - - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") - - with gr.Row(): - train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") - interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") - - params = script_callbacks.UiTrainTabParams(txt2img_preview_params) - - script_callbacks.ui_train_tabs_callback(params) - - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") - ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) - ti_progress = gr.HTML(elem_id="ti_progress", value="") - ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) - - create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, - inputs=[ - new_embedding_name, - initialization_text, - nvpt, - overwrite_old_embedding, - ], - outputs=[ - train_embedding_name, - ti_output, - ti_outcome, - ] - ) - - create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, - inputs=[ - new_hypernetwork_name, - new_hypernetwork_sizes, - overwrite_old_hypernetwork, - new_hypernetwork_layer_structure, - new_hypernetwork_activation_func, - new_hypernetwork_initialization_option, - new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout, - new_hypernetwork_dropout_structure - ], - outputs=[ - train_hypernetwork_name, - ti_output, - ti_outcome, - ] - ) - - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - - train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_embedding_name, - embedding_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - save_image_with_stored_embedding, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_hypernetwork_name, - hypernetwork_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - interrupt_training.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {str(t)} for key {key}') - - elem_id = "setting_"+key - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - with FormRow(): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - - components = [] - component_dict = {} - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' - - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return gr.update(value=value), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - with gr.Row(): - with gr.Column(scale=6): - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - with gr.Column(): - restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") - - result = gr.HTML(elem_id="settings_result") - - quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] - quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} - - quicksettings_list = [] - - previous_section = None - current_tab = None - with gr.Tabs(elem_id="settings"): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - elem_id, text = item.section - - if current_tab is not None: - current_tab.__exit__() - - current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) - current_tab.__enter__() - - previous_section = item.section - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - - if current_tab is not None: - current_tab.__exit__() - - with gr.TabItem("Actions"): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") - - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - def request_restart(): - shared.state.interrupt() - shared.state.need_restart = True - - restart_gradio.click( - fn=request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) - - interfaces = [ - (txt2img_interface, "txt2img", "txt2img"), - (img2img_interface, "img2img", "img2img"), - (extras_interface, "Extras", "extras"), - (pnginfo_interface, "PNG Info", "pnginfo"), - (modelmerger_interface, "Checkpoint Merger", "modelmerger"), - (train_interface, "Train", "ti"), - ] - - css = "" - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - with open(cssfile, "r", encoding="utf8") as file: - css += file.read() + "\n" - - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - css += file.read() + "\n" - - if not cmd_opts.no_progressbar_hiding: - css += css_hide_progressbar - - interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] - - extensions_interface = ui_extensions.create_ui() - interfaces += [(extensions_interface, "Extensions", "extensions")] - - with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings"): - for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component - - parameters_copypaste.integrate_settings_paste_fields(component_dict) - parameters_copypaste.run_bind() - - with gr.Tabs(elem_id="tabs") as tabs: - for interface, label, ifid in interfaces: - with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): - interface.render() - - if os.path.exists(os.path.join(script_path, "notification.mp3")): - audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") - - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for i, k, item in quicksettings_list: - component = component_dict[k] - - component.change( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [getattr(opts, key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - ) - - def modelmerger(*args): - try: - results = modules.extras.run_modelmerger(*args) - except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() # to remove the potentially missing models from the list - return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] - return results - - modelmerger_merge.click( - fn=modelmerger, - inputs=[ - primary_model_name, - secondary_model_name, - tertiary_model_name, - interp_method, - interp_amount, - save_as_half, - custom_name, - checkpoint_format, - ], - outputs=[ - submit_result, - primary_model_name, - secondary_model_name, - tertiary_model_name, - component_dict['sd_model_checkpoint'], - ] - ) - - ui_config_file = cmd_opts.ui_config_file - ui_settings = {} - settings_count = len(ui_settings) - error_loading = False - - try: - if os.path.exists(ui_config_file): - with open(ui_config_file, "r", encoding="utf8") as file: - ui_settings = json.load(file) - except Exception: - error_loading = True - print("Error loading settings:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - def loadsave(path, x): - def apply_field(obj, field, condition=None, init_field=None): - key = path + "/" + field - - if getattr(obj, 'custom_script_source', None) is not None: - key = 'customscript/' + obj.custom_script_source + '/' + key - - if getattr(obj, 'do_not_save_to_config', False): - return - - saved_value = ui_settings.get(key, None) - if saved_value is None: - ui_settings[key] = getattr(obj, field) - elif condition and not condition(saved_value): - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') - else: - setattr(obj, field, saved_value) - if init_field is not None: - init_field(saved_value) - - if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: - apply_field(x, 'visible') - - if type(x) == gr.Slider: - apply_field(x, 'value') - apply_field(x, 'minimum') - apply_field(x, 'maximum') - apply_field(x, 'step') - - if type(x) == gr.Radio: - apply_field(x, 'value', lambda val: val in x.choices) - - if type(x) == gr.Checkbox: - apply_field(x, 'value') - - if type(x) == gr.Textbox: - apply_field(x, 'value') - - if type(x) == gr.Number: - apply_field(x, 'value') - - if type(x) == gr.Dropdown: - apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) - - visit(txt2img_interface, loadsave, "txt2img") - visit(img2img_interface, loadsave, "img2img") - visit(extras_interface, loadsave, "extras") - visit(modelmerger_interface, loadsave, "modelmerger") - visit(train_interface, loadsave, "train") - - if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): - with open(ui_config_file, "w", encoding="utf8") as file: - json.dump(ui_settings, file, indent=4) - - return demo - - -def reload_javascript(): - with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' - - scripts_list = modules.scripts.list_scripts("javascript", ".js") - - for basedir, filename, path in scripts_list: - with open(path, "r", encoding="utf8") as jsfile: - javascript += f"\n" - - if cmd_opts.theme is not None: - javascript += f"\n\n" - - javascript += f"\n" - - def template_response(*args, **kwargs): - res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace( - b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res - - gradio.routes.templates.TemplateResponse = template_response - - -if not hasattr(shared, 'GradioTemplateResponseOriginal'): - shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse - - -def versions_html(): - import torch - import launch - - python_version = ".".join([str(x) for x in sys.version_info[0:3]]) - commit = launch.commit_hash() - short_commit = commit[0:8] - - if shared.xformers_available: - import xformers - xformers_version = xformers.__version__ - else: - xformers_version = "N/A" - - return f""" -python: {python_version} - •  -torch: {torch.__version__} - •  -xformers: {xformers_version} - •  -gradio: {gr.__version__} - •  -commit: {short_commit} -""" -- cgit v1.2.3 From 0c3feb202c5714abd50d879c1db2cd9a71ce93e3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 14:08:29 +0300 Subject: disable torch weight initialization and CLIP downloading/reading checkpoint to speedup creating sd model from config --- modules/sd_disable_initialization.py | 44 ++++++++++++++++++++++++++++++++++++ modules/sd_models.py | 5 ++-- 2 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 modules/sd_disable_initialization.py (limited to 'modules') diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py new file mode 100644 index 00000000..c9a3b5e4 --- /dev/null +++ b/modules/sd_disable_initialization.py @@ -0,0 +1,44 @@ +import ldm.modules.encoders.modules +import open_clip +import torch + + +class DisableInitialization: + """ + When an object of this class enters a `with` block, it starts preventing torch's layer initialization + functions from working, and changes CLIP and OpenCLIP to not download model weights. When it leaves, + reverts everything to how it was. + + Use like this: + ``` + with DisableInitialization(): + do_things() + ``` + """ + + def __enter__(self): + def do_nothing(*args, **kwargs): + pass + + def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs): + return self.create_model_and_transforms(*args, pretrained=None, **kwargs) + + def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): + return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) + + self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_ + self.init_no_grad_normal = torch.nn.init._no_grad_normal_ + self.create_model_and_transforms = open_clip.create_model_and_transforms + self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained + + torch.nn.init.kaiming_uniform_ = do_nothing + torch.nn.init._no_grad_normal_ = do_nothing + open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained + ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform + torch.nn.init._no_grad_normal_ = self.init_no_grad_normal + open_clip.create_model_and_transforms = self.create_model_and_transforms + ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained + diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a6d55ca..ee241032 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -13,7 +13,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -319,7 +319,8 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False - sd_model = instantiate_from_config(sd_config.model) + with sd_disable_initialization.DisableInitialization(): + sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) -- cgit v1.2.3 From ce3f639ec8758ce2bc90483336361d2dc25acd3a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 16:51:04 +0300 Subject: add more stuff to ignore when creating model from config prevent .vae.safetensors files from being listed as stable diffusion models --- modules/modelloader.py | 4 +++- modules/sd_disable_initialization.py | 29 +++++++++++++++++++++++++---- modules/sd_models.py | 32 ++++++++++++++++++++++++++++---- 3 files changed, 56 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/modelloader.py b/modules/modelloader.py index 6a1a7ac8..e9aa514e 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -10,7 +10,7 @@ from modules.upscaler import Upscaler from modules.paths import script_path, models_path -def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list: +def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. @@ -45,6 +45,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None full_path = file if os.path.isdir(full_path): continue + if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]): + continue if len(ext_filter) != 0: model_name, extension = os.path.splitext(file) if extension not in ext_filter: diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index c9a3b5e4..9942bd7e 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -1,15 +1,19 @@ import ldm.modules.encoders.modules import open_clip import torch +import transformers.utils.hub class DisableInitialization: """ - When an object of this class enters a `with` block, it starts preventing torch's layer initialization - functions from working, and changes CLIP and OpenCLIP to not download model weights. When it leaves, - reverts everything to how it was. + When an object of this class enters a `with` block, it starts: + - preventing torch's layer initialization functions from working + - changes CLIP and OpenCLIP to not download model weights + - changes CLIP to not make requests to check if there is a new version of a file you already have - Use like this: + When it leaves the block, it reverts everything to how it was before. + + Use it like this: ``` with DisableInitialization(): do_things() @@ -26,19 +30,36 @@ class DisableInitialization: def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) + def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): + + # this file is always 404, prevent making request + if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json': + raise transformers.utils.hub.EntryNotFoundError + + try: + return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=True, **kwargs) + except Exception as e: + return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs) + self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_ self.init_no_grad_normal = torch.nn.init._no_grad_normal_ + self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_ self.create_model_and_transforms = open_clip.create_model_and_transforms self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained + self.transformers_utils_hub_get_from_cache = transformers.utils.hub.get_from_cache torch.nn.init.kaiming_uniform_ = do_nothing torch.nn.init._no_grad_normal_ = do_nothing + torch.nn.init._no_grad_uniform_ = do_nothing open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained + transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache def __exit__(self, exc_type, exc_val, exc_tb): torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform torch.nn.init._no_grad_normal_ = self.init_no_grad_normal + torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_ open_clip.create_model_and_transforms = self.create_model_and_transforms ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained + transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache diff --git a/modules/sd_models.py b/modules/sd_models.py index ee241032..1bb9088b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,6 +2,7 @@ import collections import os.path import sys import gc +import time from collections import namedtuple import torch import re @@ -61,7 +62,7 @@ def find_checkpoint_config(info): def list_models(): checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"]) + model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"]) def modeltitle(path, shorthash): abspath = os.path.abspath(path) @@ -288,6 +289,17 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper +class Timer: + def __init__(self): + self.start = time.time() + + def elapsed(self): + end = time.time() + res = end - self.start + self.start = end + return res + + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -319,11 +331,17 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False + timer = Timer() + with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) + elapsed_create = timer.elapsed() + load_model_weights(sd_model, checkpoint_info) + elapsed_load_weights = timer.elapsed() + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) else: @@ -338,7 +356,9 @@ def load_model(checkpoint_info=None): script_callbacks.model_loaded_callback(sd_model) - print("Model loaded.") + elapsed_the_rest = timer.elapsed() + + print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).") return sd_model @@ -349,7 +369,7 @@ def reload_model_weights(sd_model=None, info=None): if not sd_model: sd_model = shared.sd_model - if sd_model is None: # previous model load failed + if sd_model is None: # previous model load failed current_checkpoint_info = None else: current_checkpoint_info = sd_model.sd_checkpoint_info @@ -371,6 +391,8 @@ def reload_model_weights(sd_model=None, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) + timer = Timer() + try: load_model_weights(sd_model, checkpoint_info) except Exception as e: @@ -384,6 +406,8 @@ def reload_model_weights(sd_model=None, info=None): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print("Weights loaded.") + elapsed = timer.elapsed() + + print(f"Weights loaded in {elapsed:.1f}s.") return sd_model -- cgit v1.2.3 From 0f8603a55988d22616b17140e6c4a7e9d0736af5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 17:46:59 +0300 Subject: add support for transformers==4.25.1 add fallback for when quick model creation fails --- modules/sd_disable_initialization.py | 42 ++++++++++++++++++++++++++++++------ modules/sd_models.py | 8 +++++-- 2 files changed, 42 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 9942bd7e..088ac24b 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -30,30 +30,53 @@ class DisableInitialization: def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) - def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): + def transformers_modeling_utils_load_pretrained_model(*args, **kwargs): + args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug + return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs) + + def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs): # this file is always 404, prevent making request if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json': raise transformers.utils.hub.EntryNotFoundError try: - return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=True, **kwargs) + return original(url, *args, local_files_only=True, **kwargs) except Exception as e: - return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs) + return original(url, *args, local_files_only=False, **kwargs) + + def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): + return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs) + + def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs): + return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs) + + def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs): + return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs) self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_ self.init_no_grad_normal = torch.nn.init._no_grad_normal_ self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_ self.create_model_and_transforms = open_clip.create_model_and_transforms self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained - self.transformers_utils_hub_get_from_cache = transformers.utils.hub.get_from_cache + self.transformers_modeling_utils_load_pretrained_model = getattr(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', None) + self.transformers_tokenization_utils_base_cached_file = getattr(transformers.tokenization_utils_base, 'cached_file', None) + self.transformers_configuration_utils_cached_file = getattr(transformers.configuration_utils, 'cached_file', None) + self.transformers_utils_hub_get_from_cache = getattr(transformers.utils.hub, 'get_from_cache', None) torch.nn.init.kaiming_uniform_ = do_nothing torch.nn.init._no_grad_normal_ = do_nothing torch.nn.init._no_grad_uniform_ = do_nothing open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained - transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache + if self.transformers_modeling_utils_load_pretrained_model is not None: + transformers.modeling_utils.PreTrainedModel._load_pretrained_model = transformers_modeling_utils_load_pretrained_model + if self.transformers_tokenization_utils_base_cached_file is not None: + transformers.tokenization_utils_base.cached_file = transformers_tokenization_utils_base_cached_file + if self.transformers_configuration_utils_cached_file is not None: + transformers.configuration_utils.cached_file = transformers_configuration_utils_cached_file + if self.transformers_utils_hub_get_from_cache is not None: + transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache def __exit__(self, exc_type, exc_val, exc_tb): torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform @@ -61,5 +84,12 @@ class DisableInitialization: torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_ open_clip.create_model_and_transforms = self.create_model_and_transforms ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained - transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache + if self.transformers_modeling_utils_load_pretrained_model is not None: + transformers.modeling_utils.PreTrainedModel._load_pretrained_model = self.transformers_modeling_utils_load_pretrained_model + if self.transformers_tokenization_utils_base_cached_file is not None: + transformers.utils.hub.cached_file = self.transformers_tokenization_utils_base_cached_file + if self.transformers_configuration_utils_cached_file is not None: + transformers.utils.hub.cached_file = self.transformers_configuration_utils_cached_file + if self.transformers_utils_hub_get_from_cache is not None: + transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache diff --git a/modules/sd_models.py b/modules/sd_models.py index 1bb9088b..b5bc12f0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -333,7 +333,11 @@ def load_model(checkpoint_info=None): timer = Timer() - with sd_disable_initialization.DisableInitialization(): + try: + with sd_disable_initialization.DisableInitialization(): + sd_model = instantiate_from_config(sd_config.model) + except Exception as e: + print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) sd_model = instantiate_from_config(sd_config.model) elapsed_create = timer.elapsed() -- cgit v1.2.3 From 29fb5327640465fc83111e2170c5d8aa2b15266c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 23:47:02 +0300 Subject: change color selector in settings to be part of form --- modules/shared.py | 4 ++-- modules/ui_components.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index aa37c8ce..264264a6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading, errors +from modules import localization, sd_vae, extensions, script_loading, errors, ui_components from modules.paths import models_path, script_path, sd_path @@ -387,7 +387,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), - "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}), + "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), diff --git a/modules/ui_components.py b/modules/ui_components.py index cac001dc..97acff06 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -31,3 +31,9 @@ class FormHTML(gr.HTML, gr.components.FormComponent): def get_block_name(self): return "html" + +class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): + """Same as gr.ColorPicker but fits inside gradio forms""" + + def get_block_name(self): + return "colorpicker" -- cgit v1.2.3 From 6be644fa04ce1542f3a01804310cbbc0a4a91620 Mon Sep 17 00:00:00 2001 From: dan Date: Wed, 11 Jan 2023 05:31:58 +0800 Subject: Enable batch_size>1 for mixed-sized training --- modules/textual_inversion/dataset.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index fa48708e..b47414f3 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -3,8 +3,10 @@ import numpy as np import PIL import torch from PIL import Image -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import Dataset, DataLoader, Sampler from torchvision import transforms +from collections import defaultdict +from random import shuffle, choices import random import tqdm @@ -45,12 +47,12 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - assert batch_size == 1 or not varsize, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] self.shuffle_tags = shuffle_tags self.tag_drop_out = tag_drop_out + groups = defaultdict(list) print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): @@ -103,13 +105,14 @@ class PersonalizedBase(Dataset): if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): with devices.autocast(): entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) - + groups[image.size].append(len(self.dataset)) self.dataset.append(entry) del torchdata del latent_dist del latent_sample self.length = len(self.dataset) + self.groups = list(groups.values()) assert self.length > 0, "No images have been found in the dataset." self.batch_size = min(batch_size, self.length) self.gradient_step = min(gradient_step, self.length // self.batch_size) @@ -137,9 +140,34 @@ class PersonalizedBase(Dataset): entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) return entry +class GroupedBatchSampler(Sampler): + def __init__(self, data_source: PersonalizedBase, batch_size: int): + n = len(data_source) + self.groups = data_source.groups + self.len = n_batch = n // batch_size + expected = [len(g) / n * n_batch * batch_size for g in data_source.groups] + self.base = [int(e) // batch_size for e in expected] + self.n_rand_batches = nrb = n_batch - sum(self.base) + self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] + self.batch_size = batch_size + def __len__(self): + return self.len + def __iter__(self): + b = self.batch_size + for g in self.groups: + shuffle(g) + batches = [] + for g in self.groups: + batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) + for _ in range(self.n_rand_batches): + rand_group = choices(self.groups, self.probs)[0] + batches.append(choices(rand_group, k=b)) + shuffle(batches) + yield from batches + class PersonalizedDataLoader(DataLoader): def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): - super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory) + super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) if latent_sampling_method == "random": self.collate_fn = collate_wrapper_random else: -- cgit v1.2.3 From f9706acf431f77e0ce9e4270e5be7299922ee963 Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Tue, 10 Jan 2023 18:40:34 -0700 Subject: Support loading textual inversion embeddings from safetensors files --- modules/textual_inversion/textual_inversion.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5420903f..3866c154 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -9,6 +9,7 @@ import tqdm import html import datetime import csv +import safetensors.torch from PIL import Image, PngImagePlugin @@ -150,6 +151,8 @@ class EmbeddingDatabase: name = data.get('name', name) elif ext in ['.BIN', '.PT']: data = torch.load(path, map_location="cpu") + elif ext in ['.SAFETENSORS']: + data = safetensors.torch.load_file(path, device="cpu") else: return -- cgit v1.2.3 From 5830095b73515fc49b3fd567048470005191ec34 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 10 Jan 2023 21:43:24 -0500 Subject: Add old prompt parser compat option --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 264264a6..b61bbd3f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -400,6 +400,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), + "use_old_prompt_parser_default_step_transformer": OptionInfo(False, "Use old prompt parser default step transformer. In particular, alternating words that contained emphasis were not parsed correctly. Useful to reproduce old seeds."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { -- cgit v1.2.3 From 7e45fba55b24166501033a221e6268545fa47fbe Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 10 Jan 2023 21:47:03 -0500 Subject: Fix prompt parser default step transformer w/ test --- modules/prompt_parser.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index f70872c4..b69f1425 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -3,6 +3,11 @@ from collections import namedtuple from typing import List import lark +try: + from modules.shared import opts +except: + pass + # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] @@ -49,6 +54,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): [[5, 'a c'], [10, 'a {b|d{ c']] >>> g("((a][:b:c [d:3]") [[3, '((a][:b:c '], [10, '((a][:b:c d']] + >>> g("[a|(b:1.1)]") + [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']] """ def collect_steps(steps, tree): @@ -84,7 +91,13 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): yield args[0].value def __default__(self, data, children, meta): for child in children: - yield from child + try: + if opts.use_old_prompt_parser_default_step_transformer: + yield from child + else: + yield child + except: + yield child return AtStep().transform(tree) def get_schedule(prompt): -- cgit v1.2.3 From 37a230112198adcb3f24d59b399cff342a6d479e Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 10 Jan 2023 20:30:09 -0800 Subject: Expose the compiled class module of scripts to extensions --- modules/scripts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 35164093..4ffc369b 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -152,7 +152,7 @@ def basedir(): scripts_data = [] ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) -ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"]) +ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) def list_scripts(scriptdirname, extension): @@ -206,7 +206,7 @@ def load_scripts(): for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): - scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir)) + scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) except Exception: print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) @@ -241,7 +241,7 @@ class ScriptRunner: self.alwayson_scripts.clear() self.selectable_scripts.clear() - for script_class, path, basedir in scripts_data: + for script_class, path, basedir, script_module in scripts_data: script = script_class() script.filename = path script.is_txt2img = not is_img2img -- cgit v1.2.3 From 954091697fce7a1b7997d5f3d73551f793f6bebc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 09:10:07 +0300 Subject: add an option to copy config from one of models in checkpoint merger --- modules/extras.py | 30 +++++++++++++++++++++++++++++- modules/ui.py | 9 ++++++--- 2 files changed, 35 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 7407bfe3..a03d558e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -3,6 +3,7 @@ import math import os import sys import traceback +import shutil import numpy as np from PIL import Image @@ -248,7 +249,32 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format): +def create_config(ckpt_result, config_source, a, b, c): + def config(x): + return sd_models.find_checkpoint_config(x) if x else None + + if config_source == 0: + cfg = config(a) or config(b) or config(c) + elif config_source == 1: + cfg = config(b) + elif config_source == 2: + cfg = config(c) + else: + cfg = None + + if cfg is None: + return + + filename, _ = os.path.splitext(ckpt_result) + checkpoint_filename = filename + ".yaml" + + print("Copying config:") + print(" from:", cfg) + print(" to:", checkpoint_filename) + shutil.copyfile(cfg, checkpoint_filename) + + +def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): shared.state.begin() shared.state.job = 'model-merge' @@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam sd_models.list_models() + create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) + print("Checkpoint saved.") shared.state.textinfo = "Checkpoint saved to " + output_modelname shared.state.end() diff --git a/modules/ui.py b/modules/ui.py index 3c458ce8..82f5dd7c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1129,7 +1129,7 @@ def create_ui(): with gr.Column(variant='panel'): gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - with gr.Row(): + with FormRow(): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") @@ -1143,11 +1143,13 @@ def create_ui(): interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") - with gr.Row(): + with FormRow(): checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) @@ -1703,6 +1705,7 @@ def create_ui(): save_as_half, custom_name, checkpoint_format, + config_source, ], outputs=[ submit_result, -- cgit v1.2.3 From 4fdacd31e48c6a7a35c1c25c559932585e8addde Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 10:24:56 +0300 Subject: possible fix for fallback for fast model creation from config --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index b5bc12f0..a0a8a909 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -337,6 +337,9 @@ def load_model(checkpoint_info=None): with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) except Exception as e: + pass + + if sd_model is None: print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) sd_model = instantiate_from_config(sd_config.model) -- cgit v1.2.3 From 1a23dc32ac5e16fac10115cafd0b841abd06e59f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 10:34:36 +0300 Subject: possible fix for fallback for fast model creation from config, attempt 2 --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index a0a8a909..084ba7fa 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -333,6 +333,7 @@ def load_model(checkpoint_info=None): timer = Timer() + sd_model = None try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) -- cgit v1.2.3 From ab388d6f8bf51338de1950b3907c324b0ff6a872 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 11 Jan 2023 08:59:47 -0500 Subject: Remove compat option check for prompt parser --- modules/prompt_parser.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index b69f1425..870218db 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -3,11 +3,6 @@ from collections import namedtuple from typing import List import lark -try: - from modules.shared import opts -except: - pass - # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] @@ -91,13 +86,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): yield args[0].value def __default__(self, data, children, meta): for child in children: - try: - if opts.use_old_prompt_parser_default_step_transformer: - yield from child - else: - yield child - except: - yield child + yield child return AtStep().transform(tree) def get_schedule(prompt): -- cgit v1.2.3 From 0b38b72d31ead82c7d0998a29e50da90073831f7 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 11 Jan 2023 09:01:37 -0500 Subject: Remove compat option for prompt parser --- modules/shared.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index b61bbd3f..264264a6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -400,7 +400,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), - "use_old_prompt_parser_default_step_transformer": OptionInfo(False, "Use old prompt parser default step transformer. In particular, alternating words that contained emphasis were not parsed correctly. Useful to reproduce old seeds."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { -- cgit v1.2.3 From 39ea251945d70efcf9b59d44eb0e71269d754aa4 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 11 Jan 2023 10:23:51 -0500 Subject: add textinfo to progress response --- modules/api/api.py | 4 ++-- modules/api/models.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 6c564ad8..5767ba90 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -286,7 +286,7 @@ class Api: # copy from check_progress_call of ui.py if shared.state.job_count == 0: - return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict()) + return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo) # avoid dividing zero progress = 0.01 @@ -308,7 +308,7 @@ class Api: if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) - return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image) + return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) def interrogateapi(self, interrogatereq: InterrogateRequest): image_b64 = interrogatereq.image diff --git a/modules/api/models.py b/modules/api/models.py index 034b4aa0..c78095ca 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -168,6 +168,7 @@ class ProgressResponse(BaseModel): eta_relative: float = Field(title="ETA in secs") state: dict = Field(title="State", description="The current state snapshot") current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") + textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.") class InterrogateRequest(BaseModel): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") -- cgit v1.2.3 From 3f43d8a966ba8462ba019a5ad573f94508cd45f8 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 11 Jan 2023 10:28:55 -0500 Subject: set descriptions --- modules/hypernetworks/hypernetwork.py | 4 +++- modules/textual_inversion/preprocess.py | 7 ++++++- modules/textual_inversion/textual_inversion.py | 4 +++- 3 files changed, 12 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 300d3975..194679e8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -619,7 +619,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, epoch_num = hypernetwork.step // steps_per_epoch epoch_step = hypernetwork.step % steps_per_epoch - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") + description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" + pbar.set_description(description) + shared.state.textinfo = description if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index feb876c6..3c1042ad 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -135,7 +135,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre params.process_caption_deepbooru = process_caption_deepbooru params.preprocess_txt_action = preprocess_txt_action - for index, imagefile in enumerate(tqdm.tqdm(files)): + pbar = tqdm.tqdm(files) + for index, imagefile in enumerate(pbar): params.subindex = 0 filename = os.path.join(src, imagefile) try: @@ -143,6 +144,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre except Exception: continue + description = f"Preprocessing [Image {index}/{len(files)}]" + pbar.set_description(description) + shared.state.textinfo = description + params.src = filename existing_caption = None diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 3866c154..b915b091 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -476,7 +476,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ epoch_num = embedding.step // steps_per_epoch epoch_step = embedding.step % steps_per_epoch - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") + description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" + pbar.set_description(description) + shared.state.textinfo = description if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. embedding_name_every = f'{embedding_name}-{steps_done}' -- cgit v1.2.3 From 4bd490727e156ff53107d53416d6b89be86f2a62 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 18:54:04 +0300 Subject: fix for an error caused by skipping initialization, for realsies this time: TypeError: expected str, bytes or os.PathLike object, not NoneType --- modules/sd_disable_initialization.py | 71 ++++++++++++++++-------------------- modules/sd_models.py | 1 + 2 files changed, 33 insertions(+), 39 deletions(-) (limited to 'modules') diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 088ac24b..c72d8efc 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -20,6 +20,19 @@ class DisableInitialization: ``` """ + def __init__(self): + self.replaced = [] + + def replace(self, obj, field, func): + original = getattr(obj, field, None) + if original is None: + return None + + self.replaced.append((obj, field, original)) + setattr(obj, field, func) + + return original + def __enter__(self): def do_nothing(*args, **kwargs): pass @@ -37,11 +50,14 @@ class DisableInitialization: def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs): # this file is always 404, prevent making request - if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json': - raise transformers.utils.hub.EntryNotFoundError + if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json': + return None try: - return original(url, *args, local_files_only=True, **kwargs) + res = original(url, *args, local_files_only=True, **kwargs) + if res is None: + res = original(url, *args, local_files_only=False, **kwargs) + return res except Exception as e: return original(url, *args, local_files_only=False, **kwargs) @@ -54,42 +70,19 @@ class DisableInitialization: def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs): return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs) - self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_ - self.init_no_grad_normal = torch.nn.init._no_grad_normal_ - self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_ - self.create_model_and_transforms = open_clip.create_model_and_transforms - self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained - self.transformers_modeling_utils_load_pretrained_model = getattr(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', None) - self.transformers_tokenization_utils_base_cached_file = getattr(transformers.tokenization_utils_base, 'cached_file', None) - self.transformers_configuration_utils_cached_file = getattr(transformers.configuration_utils, 'cached_file', None) - self.transformers_utils_hub_get_from_cache = getattr(transformers.utils.hub, 'get_from_cache', None) - - torch.nn.init.kaiming_uniform_ = do_nothing - torch.nn.init._no_grad_normal_ = do_nothing - torch.nn.init._no_grad_uniform_ = do_nothing - open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained - ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained - if self.transformers_modeling_utils_load_pretrained_model is not None: - transformers.modeling_utils.PreTrainedModel._load_pretrained_model = transformers_modeling_utils_load_pretrained_model - if self.transformers_tokenization_utils_base_cached_file is not None: - transformers.tokenization_utils_base.cached_file = transformers_tokenization_utils_base_cached_file - if self.transformers_configuration_utils_cached_file is not None: - transformers.configuration_utils.cached_file = transformers_configuration_utils_cached_file - if self.transformers_utils_hub_get_from_cache is not None: - transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache + self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing) + self.replace(torch.nn.init, '_no_grad_normal_', do_nothing) + self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing) + self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) + self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) + self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) + self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) + self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) + self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) def __exit__(self, exc_type, exc_val, exc_tb): - torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform - torch.nn.init._no_grad_normal_ = self.init_no_grad_normal - torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_ - open_clip.create_model_and_transforms = self.create_model_and_transforms - ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained - if self.transformers_modeling_utils_load_pretrained_model is not None: - transformers.modeling_utils.PreTrainedModel._load_pretrained_model = self.transformers_modeling_utils_load_pretrained_model - if self.transformers_tokenization_utils_base_cached_file is not None: - transformers.utils.hub.cached_file = self.transformers_tokenization_utils_base_cached_file - if self.transformers_configuration_utils_cached_file is not None: - transformers.utils.hub.cached_file = self.transformers_configuration_utils_cached_file - if self.transformers_utils_hub_get_from_cache is not None: - transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache + for obj, field, original in self.replaced: + setattr(obj, field, original) + + self.replaced.clear() diff --git a/modules/sd_models.py b/modules/sd_models.py index 084ba7fa..c466f273 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -334,6 +334,7 @@ def load_model(checkpoint_info=None): timer = Timer() sd_model = None + try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) -- cgit v1.2.3 From 0b8911d883118daa54f7735c5b753b5575d9f943 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 20:33:24 +0300 Subject: img2img UI rework: obsolete --gradio-img2img-tool --gradio-inpaint-tool and always show all tools each in own tab --- modules/img2img.py | 58 ++++++++++++++---------------- modules/shared.py | 4 +-- modules/ui.py | 103 +++++++++++++++++++++++++++-------------------------- style.css | 4 ++- 4 files changed, 84 insertions(+), 85 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index ca58b5d8..f62783c6 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,38 +59,34 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): - is_inpaint = mode == 1 - is_batch = mode == 2 - - if is_inpaint: - # Drawn mask - if mask_mode == 0: - is_mask_sketch = isinstance(init_img_with_mask, dict) - is_mask_paint = not is_mask_sketch - if is_mask_sketch: - # Sketch: mask iff. not transparent - image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] - alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') - mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') - else: - # Color-sketch: mask iff. painted over - image = init_img_with_mask - orig = init_img_with_mask_orig or init_img_with_mask - pred = np.any(np.array(image) != np.array(orig), axis=-1) - mask = Image.fromarray(pred.astype(np.uint8) * 255, "L") - mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100) - blur = ImageFilter.GaussianBlur(mask_blur) - image = Image.composite(image.filter(blur), orig, mask.filter(blur)) - - image = image.convert("RGB") - # Uploaded mask - else: - image = init_img_inpaint - mask = init_mask_inpaint - # No mask +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): + is_batch = mode == 5 + + if mode == 0: # img2img + image = init_img.convert("RGB") + mask = None + elif mode == 1: # img2img sketch + image = sketch.convert("RGB") + mask = None + elif mode == 2: # inpaint + image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] + alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') + mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') + image = image.convert("RGB") + elif mode == 3: # inpaint sketch + image = inpaint_color_sketch + orig = inpaint_color_sketch_orig or inpaint_color_sketch + pred = np.any(np.array(image) != np.array(orig), axis=-1) + mask = Image.fromarray(pred.astype(np.uint8) * 255, "L") + mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100) + blur = ImageFilter.GaussianBlur(mask_blur) + image = Image.composite(image.filter(blur), orig, mask.filter(blur)) + image = image.convert("RGB") + elif mode == 4: # inpaint upload mask + image = init_img_inpaint + mask = init_mask_inpaint else: - image = init_img + image = None mask = None # Use the EXIF orientation of photos taken by smartphones. diff --git a/modules/shared.py b/modules/shared.py index 264264a6..1c964237 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,8 +74,8 @@ parser.add_argument("--freeze-settings", action='store_true', help="disable edit parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) -parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor") -parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it") +parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') +parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) diff --git a/modules/ui.py b/modules/ui.py index 82f5dd7c..e86a624b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -795,53 +795,67 @@ def create_ui(): with FormRow().style(equal_height=False): with gr.Column(variant='panel', elem_id="img2img_settings"): + with gr.Tabs(elem_id="mode_img2img"): + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480) - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) + with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480) - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) - init_img_with_mask_orig = gr.State(None) + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) - use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" - if use_color_sketch: - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state + with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480) + inpaint_color_sketch_orig = gr.State(None) - init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") - - with FormRow(): - mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask") - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): + with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + + with FormRow(): + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + def select_img2img_tab(tab): + return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), + + for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]): + elem.select( + fn=lambda tab=i: select_img2img_tab(tab), + inputs=[], + outputs=[inpaint_controls, mask_alpha], + ) + with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") @@ -900,20 +914,6 @@ def create_ui(): ] ) - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), _js="submit_img2img", @@ -924,11 +924,12 @@ def create_ui(): img2img_prompt_style, img2img_prompt_style2, init_img, + sketch, init_img_with_mask, - init_img_with_mask_orig, + inpaint_color_sketch, + inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, - mask_mode, steps, sampler_index, mask_blur, diff --git a/style.css b/style.css index ec5e4182..ffd6307f 100644 --- a/style.css +++ b/style.css @@ -557,7 +557,9 @@ canvas[key="mask"] { } #img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img, -img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img +#img2img_sketch, #img2img_sketch > .h-60, #img2img_sketch > .h-60 > div, #img2img_sketch > .h-60 > div > img, +#img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img, +#inpaint_sketch, #inpaint_sketch > .h-60, #inpaint_sketch > .h-60 > div, #inpaint_sketch > .h-60 > div > img { height: 480px !important; max-height: 480px !important; -- cgit v1.2.3 From d52a80f7f7da160c73afd067c8f1bf491391f994 Mon Sep 17 00:00:00 2001 From: Shondoit Date: Thu, 12 Jan 2023 09:22:29 +0100 Subject: Allow creation of zero vectors for TI --- modules/textual_inversion/textual_inversion.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b915b091..853246a6 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -248,11 +248,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): with devices.autocast(): cond_model([""]) # will send cond model to GPU if lowvram/medvram is active - embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token) + #cond_model expects at least some text, so we provide '*' as backup. + embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) - for i in range(num_vectors_per_token): - vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] + #Only copy if we provided an init_text, otherwise keep vectors as zeros + if init_text: + for i in range(num_vectors_per_token): + vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] # Remove illegal characters from name. name = "".join( x for x in name if (x.isalnum() or x in "._- ")) -- cgit v1.2.3 From 88416ab5ff787eec3b9962b43b5e544bb75fbad6 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 12 Jan 2023 13:46:59 -0800 Subject: Fix extension parameters not being saved to last used parameters --- modules/processing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index f04a0e1e..ae04cab7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -531,16 +531,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: def infotext(iteration=0, position_in_batch=0): return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) - with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: - processed = Processed(p, [], p.seed, "") - file.write(processed.infotext(p, 0)) - if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() if p.scripts is not None: p.scripts.process(p) + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: + processed = Processed(p, [], p.seed, "") + file.write(processed.infotext(p, 0)) + infotexts = [] output_images = [] -- cgit v1.2.3 From 6c88eaed4f5efca54a882eb1f8f30f01f350332a Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 12 Jan 2023 13:50:09 -0800 Subject: Add script callback for fixing infotext parameters --- modules/generation_parameters_copypaste.py | 3 ++- modules/script_callbacks.py | 20 +++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 620aa606..593d99ef 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -7,7 +7,7 @@ from pathlib import Path import gradio as gr from modules.shared import script_path -from modules import shared, ui_tempdir +from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image @@ -298,6 +298,7 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None): prompt = file.read() params = parse_generation_parameters(prompt) + script_callbacks.infotext_pasted_callback(prompt, params) res = [] for output, key in paste_fields: diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 608c5300..a9e19236 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,7 +2,7 @@ import sys import traceback from collections import namedtuple import inspect -from typing import Optional +from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks @@ -71,6 +71,7 @@ callback_map = dict( callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], + callbacks_infotext_pasted=[], callbacks_script_unloaded=[], ) @@ -172,6 +173,14 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid') +def infotext_pasted_callback(infotext: str, params: Dict[str, Any]): + for c in callback_map['callbacks_infotext_pasted']: + try: + c.callback(infotext, params) + except Exception: + report_exception(c, 'infotext_pasted') + + def script_unloaded_callback(): for c in reversed(callback_map['callbacks_script_unloaded']): try: @@ -290,6 +299,15 @@ def on_image_grid(callback): add_callback(callback_map['callbacks_image_grid'], callback) +def on_infotext_pasted(callback): + """register a function to be called before applying an infotext. + The callback is called with two arguments: + - infotext: str - raw infotext. + - result: Dict[str, any] - parsed infotext parameters. + """ + add_callback(callback_map['callbacks_infotext_pasted'], callback) + + def on_script_unloaded(callback): """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that the script did should be reverted here""" -- cgit v1.2.3 From 0b262802b86a55c4f71faf377f2cb1aee2960b63 Mon Sep 17 00:00:00 2001 From: Josh R Date: Thu, 12 Jan 2023 17:31:05 -0800 Subject: add gradient settings to training settings log files --- modules/textual_inversion/logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py index 8b1981d5..31e50b64 100644 --- a/modules/textual_inversion/logging.py +++ b/modules/textual_inversion/logging.py @@ -2,7 +2,7 @@ import datetime import json import os -saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"} +saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"} saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet -- cgit v1.2.3 From a176d89487d92f5a5b152401e5c424b34ff43b96 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 13 Jan 2023 14:32:15 +0300 Subject: print bucket sizes for training without resizing images #6620 fix an error when generating a picture with embedding in it --- modules/textual_inversion/dataset.py | 16 ++++++++++++++++ modules/textual_inversion/image_embedding.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 2 +- 3 files changed, 19 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index b47414f3..d31963d4 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -118,6 +118,12 @@ class PersonalizedBase(Dataset): self.gradient_step = min(gradient_step, self.length // self.batch_size) self.latent_sampling_method = latent_sampling_method + if len(groups) > 1: + print("Buckets:") + for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]): + print(f" {w}x{h}: {len(ids)}") + print() + def create_text(self, filename_text): text = random.choice(self.lines) tags = filename_text.split(',') @@ -140,8 +146,11 @@ class PersonalizedBase(Dataset): entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) return entry + class GroupedBatchSampler(Sampler): def __init__(self, data_source: PersonalizedBase, batch_size: int): + super().__init__(data_source) + n = len(data_source) self.groups = data_source.groups self.len = n_batch = n // batch_size @@ -150,21 +159,28 @@ class GroupedBatchSampler(Sampler): self.n_rand_batches = nrb = n_batch - sum(self.base) self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] self.batch_size = batch_size + def __len__(self): return self.len + def __iter__(self): b = self.batch_size + for g in self.groups: shuffle(g) + batches = [] for g in self.groups: batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) for _ in range(self.n_rand_batches): rand_group = choices(self.groups, self.probs)[0] batches.append(choices(rand_group, k=b)) + shuffle(batches) + yield from batches + class PersonalizedDataLoader(DataLoader): def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index ea653806..5593f88c 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -76,10 +76,10 @@ def insert_image_data_embed(image, data): next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h)) next_size = next_size + ((h*d)-(next_size % (h*d))) - data_np_low.resize(next_size) + data_np_low = np.resize(data_np_low, next_size) data_np_low = data_np_low.reshape((h, -1, d)) - data_np_high.resize(next_size) + data_np_high = np.resize(data_np_high, next_size) data_np_high = data_np_high.reshape((h, -1, d)) edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 853246a6..e23906ca 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -479,7 +479,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ epoch_num = embedding.step // steps_per_epoch epoch_step = embedding.step % steps_per_epoch - description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" + description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}" pbar.set_description(description) shared.state.textinfo = description if embedding_dir is not None and steps_done % save_embedding_every == 0: -- cgit v1.2.3