From ad4de819c43997f2666b5bad95301f5c37f9018e Mon Sep 17 00:00:00 2001
From: victorca25
Date: Sun, 9 Oct 2022 13:02:12 +0200
Subject: update ESRGAN architecture and model to support all ESRGAN models in
 the DB, plus BSRGAN and Real-ESRGAN models
---
modules/bsrgan_model.py | 76 -------
modules/bsrgan_model_arch.py | 102 ----------
modules/esrgam_model_arch.py | 80 --------
modules/esrgan_model.py | 190 ++++++++++++------
modules/esrgan_model_arch.py | 463 +++++++++++++++++++++++++++++++++++++++++++
5 files changed, 591 insertions(+), 320 deletions(-)
delete mode 100644 modules/bsrgan_model.py
delete mode 100644 modules/bsrgan_model_arch.py
delete mode 100644 modules/esrgam_model_arch.py
create mode 100644 modules/esrgan_model_arch.py
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py
deleted file mode 100644
index 737e1a76..00000000
--- a/modules/bsrgan_model.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import os.path
-import sys
-import traceback
-
-import PIL.Image
-import numpy as np
-import torch
-from basicsr.utils.download_util import load_file_from_url
-
-import modules.upscaler
-from modules import devices, modelloader
-from modules.bsrgan_model_arch import RRDBNet
-
-
-class UpscalerBSRGAN(modules.upscaler.Upscaler):
- def __init__(self, dirname):
- self.name = "BSRGAN"
- self.model_name = "BSRGAN 4x"
- self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
- self.user_path = dirname
- super().__init__()
- model_paths = self.find_models(ext_filter=[".pt", ".pth"])
- scalers = []
- if len(model_paths) == 0:
- scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
- scalers.append(scaler_data)
- for file in model_paths:
- if "http" in file:
- name = self.model_name
- else:
- name = modelloader.friendly_name(file)
- try:
- scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
- scalers.append(scaler_data)
- except Exception:
- print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- self.scalers = scalers
-
- def do_upscale(self, img: PIL.Image, selected_file):
- torch.cuda.empty_cache()
- model = self.load_model(selected_file)
- if model is None:
- return img
- model.to(devices.device_bsrgan)
- torch.cuda.empty_cache()
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_bsrgan)
- with torch.no_grad():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- torch.cuda.empty_cache()
- return PIL.Image.fromarray(output, 'RGB')
-
- def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
- progress=True)
- else:
- filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
- return None
- model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
- model.load_state_dict(torch.load(filename), strict=True)
- model.eval()
- for k, v in model.named_parameters():
- v.requires_grad = False
- return model
-
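The deleted do_upscale above round-trips a PIL image through a BGR, channel-first
float tensor and back. A minimal standalone sketch of that conversion (the helper
names here are illustrative, not part of the codebase):

    import numpy as np
    import PIL.Image
    import torch

    def pil_to_tensor(img: PIL.Image.Image) -> torch.Tensor:
        arr = np.array(img)[:, :, ::-1]          # RGB -> BGR
        arr = np.moveaxis(arr, 2, 0) / 255.      # HWC -> CHW, scale to [0, 1]
        return torch.from_numpy(arr).float().unsqueeze(0)  # add batch dim

    def tensor_to_pil(t: torch.Tensor) -> PIL.Image.Image:
        arr = t.squeeze().float().cpu().clamp_(0, 1).numpy()
        arr = (255. * np.moveaxis(arr, 0, 2)).astype(np.uint8)  # CHW -> HWC
        return PIL.Image.fromarray(arr[:, :, ::-1], 'RGB')      # BGR -> RGB
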
diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py
deleted file mode 100644
index cb4d1c13..00000000
--- a/modules/bsrgan_model_arch.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import functools
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.nn.init as init
-
-
-def initialize_weights(net_l, scale=1):
- if not isinstance(net_l, list):
- net_l = [net_l]
- for net in net_l:
- for m in net.modules():
- if isinstance(m, nn.Conv2d):
- init.kaiming_normal_(m.weight, a=0, mode='fan_in')
- m.weight.data *= scale # for residual block
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.Linear):
- init.kaiming_normal_(m.weight, a=0, mode='fan_in')
- m.weight.data *= scale
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.BatchNorm2d):
- init.constant_(m.weight, 1)
- init.constant_(m.bias.data, 0.0)
-
-
-def make_layer(block, n_layers):
- layers = []
- for _ in range(n_layers):
- layers.append(block())
- return nn.Sequential(*layers)
-
-
-class ResidualDenseBlock_5C(nn.Module):
- def __init__(self, nf=64, gc=32, bias=True):
- super(ResidualDenseBlock_5C, self).__init__()
- # gc: growth channel, i.e. intermediate channels
- self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
- self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
- self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
- self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
- self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- # initialization
- initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
-
- def forward(self, x):
- x1 = self.lrelu(self.conv1(x))
- x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
- x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
- x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- return x5 * 0.2 + x
-
-
-class RRDB(nn.Module):
- '''Residual in Residual Dense Block'''
-
- def __init__(self, nf, gc=32):
- super(RRDB, self).__init__()
- self.RDB1 = ResidualDenseBlock_5C(nf, gc)
- self.RDB2 = ResidualDenseBlock_5C(nf, gc)
- self.RDB3 = ResidualDenseBlock_5C(nf, gc)
-
- def forward(self, x):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
- return out * 0.2 + x
-
-
-class RRDBNet(nn.Module):
- def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
- super(RRDBNet, self).__init__()
- RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
- self.sf = sf
-
- self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
- self.RRDB_trunk = make_layer(RRDB_block_f, nb)
- self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- #### upsampling
- self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- if self.sf==4:
- self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
-
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- def forward(self, x):
- fea = self.conv_first(x)
- trunk = self.trunk_conv(self.RRDB_trunk(fea))
- fea = fea + trunk
-
- fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
- if self.sf==4:
- fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
- out = self.conv_last(self.lrelu(self.HRconv(fea)))
-
- return out
\ No newline at end of file
diff --git a/modules/esrgam_model_arch.py b/modules/esrgam_model_arch.py
deleted file mode 100644
index e413d36e..00000000
--- a/modules/esrgam_model_arch.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# this file is taken from https://github.com/xinntao/ESRGAN
-
-import functools
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-def make_layer(block, n_layers):
- layers = []
- for _ in range(n_layers):
- layers.append(block())
- return nn.Sequential(*layers)
-
-
-class ResidualDenseBlock_5C(nn.Module):
- def __init__(self, nf=64, gc=32, bias=True):
- super(ResidualDenseBlock_5C, self).__init__()
- # gc: growth channel, i.e. intermediate channels
- self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
- self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
- self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
- self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
- self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- # initialization
- # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
-
- def forward(self, x):
- x1 = self.lrelu(self.conv1(x))
- x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
- x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
- x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- return x5 * 0.2 + x
-
-
-class RRDB(nn.Module):
- '''Residual in Residual Dense Block'''
-
- def __init__(self, nf, gc=32):
- super(RRDB, self).__init__()
- self.RDB1 = ResidualDenseBlock_5C(nf, gc)
- self.RDB2 = ResidualDenseBlock_5C(nf, gc)
- self.RDB3 = ResidualDenseBlock_5C(nf, gc)
-
- def forward(self, x):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
- return out * 0.2 + x
-
-
-class RRDBNet(nn.Module):
- def __init__(self, in_nc, out_nc, nf, nb, gc=32):
- super(RRDBNet, self).__init__()
- RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
-
- self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
- self.RRDB_trunk = make_layer(RRDB_block_f, nb)
- self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- #### upsampling
- self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
-
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- def forward(self, x):
- fea = self.conv_first(x)
- trunk = self.trunk_conv(self.RRDB_trunk(fea))
- fea = fea + trunk
-
- fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
- fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
- out = self.conv_last(self.lrelu(self.HRconv(fea)))
-
- return out
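All three deleted files pin fixed RRDBNet variants (the BSRGAN sf-aware one and
the original ESRGAN one); the single parameterized architecture added below
replaces them. The updated load_model instantiates it as:

    model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)

with every parameter inferred from the checkpoint's state dict rather than
hard-coded.
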
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 3970e6e4..a49e2258 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -5,68 +5,115 @@ import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
-import modules.esrgam_model_arch as arch
+import modules.esrgan_model_arch as arch
from modules import shared, modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-def fix_model_layers(crt_model, pretrained_net):
- # this code is adapted from https://github.com/xinntao/ESRGAN
- if 'conv_first.weight' in pretrained_net:
- return pretrained_net
- if 'model.0.weight' not in pretrained_net:
- is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
- if is_realesrgan:
- raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
- else:
- raise Exception("The file is not a ESRGAN model.")
+def mod2normal(state_dict):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ if 'conv_first.weight' in state_dict:
+ crt_net = {}
+ items = []
+ for k, v in state_dict.items():
+ items.append(k)
+
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+ for k in items.copy():
+ if 'RDB' in k:
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[ori_k] = state_dict[k]
+ items.remove(k)
+
+ crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
+ crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
+ crt_net['model.3.weight'] = state_dict['upconv1.weight']
+ crt_net['model.3.bias'] = state_dict['upconv1.bias']
+ crt_net['model.6.weight'] = state_dict['upconv2.weight']
+ crt_net['model.6.bias'] = state_dict['upconv2.bias']
+ crt_net['model.8.weight'] = state_dict['HRconv.weight']
+ crt_net['model.8.bias'] = state_dict['HRconv.bias']
+ crt_net['model.10.weight'] = state_dict['conv_last.weight']
+ crt_net['model.10.bias'] = state_dict['conv_last.bias']
+ state_dict = crt_net
+ return state_dict
+
+
+def resrgan2normal(state_dict, nb=23):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+ crt_net = {}
+ items = []
+ for k, v in state_dict.items():
+ items.append(k)
+
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+ for k in items.copy():
+ if "rdb" in k:
+ ori_k = k.replace('body.', 'model.1.sub.')
+ ori_k = ori_k.replace('.rdb', '.RDB')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[ori_k] = state_dict[k]
+ items.remove(k)
+
+ crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
+ crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
+ crt_net['model.3.weight'] = state_dict['conv_up1.weight']
+ crt_net['model.3.bias'] = state_dict['conv_up1.bias']
+ crt_net['model.6.weight'] = state_dict['conv_up2.weight']
+ crt_net['model.6.bias'] = state_dict['conv_up2.bias']
+ crt_net['model.8.weight'] = state_dict['conv_hr.weight']
+ crt_net['model.8.bias'] = state_dict['conv_hr.bias']
+ crt_net['model.10.weight'] = state_dict['conv_last.weight']
+ crt_net['model.10.bias'] = state_dict['conv_last.bias']
+ state_dict = crt_net
+ return state_dict
+
+
+def infer_params(state_dict):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ scale2x = 0
+ scalemin = 6
+ n_uplayer = 0
+ plus = False
+
+ for block in list(state_dict):
+ parts = block.split(".")
+ n_parts = len(parts)
+ if n_parts == 5 and parts[2] == "sub":
+ nb = int(parts[3])
+ elif n_parts == 3:
+ part_num = int(parts[1])
+ if (part_num > scalemin
+ and parts[0] == "model"
+ and parts[2] == "weight"):
+ scale2x += 1
+ if part_num > n_uplayer:
+ n_uplayer = part_num
+ out_nc = state_dict[block].shape[0]
+ if not plus and "conv1x1" in block:
+ plus = True
+
+ nf = state_dict["model.0.weight"].shape[0]
+ in_nc = state_dict["model.0.weight"].shape[1]
+ scale = 2 ** scale2x
+
+ return in_nc, out_nc, nf, nb, plus, scale
- crt_net = crt_model.state_dict()
- load_net_clean = {}
- for k, v in pretrained_net.items():
- if k.startswith('module.'):
- load_net_clean[k[7:]] = v
- else:
- load_net_clean[k] = v
- pretrained_net = load_net_clean
-
- tbd = []
- for k, v in crt_net.items():
- tbd.append(k)
-
- # directly copy
- for k, v in crt_net.items():
- if k in pretrained_net and pretrained_net[k].size() == v.size():
- crt_net[k] = pretrained_net[k]
- tbd.remove(k)
-
- crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
- crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
-
- for k in tbd.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[k] = pretrained_net[ori_k]
- tbd.remove(k)
-
- crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
- crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
- crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
- crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
- crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
- crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
- crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
- crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
- crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
- crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
-
- return crt_net
class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
@@ -109,20 +156,39 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename))
return None
- pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
- crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+ state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
+
+ if "params_ema" in state_dict:
+ state_dict = state_dict["params_ema"]
+ elif "params" in state_dict:
+ state_dict = state_dict["params"]
+ num_conv = 16 if "realesr-animevideov3" in filename else 32
+ model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
+ nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
+ state_dict = resrgan2normal(state_dict, nb)
+ elif "conv_first.weight" in state_dict:
+ state_dict = mod2normal(state_dict)
+ elif "model.0.weight" not in state_dict:
+ raise Exception("The file is not a recognized ESRGAN model.")
+
+ in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
- pretrained_net = fix_model_layers(crt_model, pretrained_net)
- crt_model.load_state_dict(pretrained_net)
- crt_model.eval()
+ model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
+ model.load_state_dict(state_dict)
+ model.eval()
- return crt_model
+ return model
def upscale_without_tiling(model, img):
img = np.array(img)
img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
+ img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad():
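As a quick check of the key remapping, mod2normal moves new-arch keys into the
old-arch "model.N" namespace. A toy state dict with dummy tensors (not a real
checkpoint, and assuming mod2normal is imported from modules.esrgan_model above)
exercises it end to end:

    import torch

    def t(*shape):  # dummy tensors just to exercise the remap
        return torch.zeros(*shape)

    new_style = {
        'conv_first.weight': t(64, 3, 3, 3), 'conv_first.bias': t(64),
        'RRDB_trunk.0.RDB1.conv1.weight': t(32, 64, 3, 3),
        'RRDB_trunk.0.RDB1.conv1.bias': t(32),
        'trunk_conv.weight': t(64, 64, 3, 3), 'trunk_conv.bias': t(64),
        'upconv1.weight': t(64, 64, 3, 3), 'upconv1.bias': t(64),
        'upconv2.weight': t(64, 64, 3, 3), 'upconv2.bias': t(64),
        'HRconv.weight': t(64, 64, 3, 3), 'HRconv.bias': t(64),
        'conv_last.weight': t(3, 64, 3, 3), 'conv_last.bias': t(3),
    }

    old_style = mod2normal(new_style)
    assert 'model.0.weight' in old_style                     # was conv_first.weight
    assert 'model.1.sub.0.RDB1.conv1.0.weight' in old_style  # was RRDB_trunk.0...
    assert 'model.10.bias' in old_style                      # was conv_last.bias

infer_params then recovers nf and in_nc from the shape of model.0.weight, nb from
the model.1.sub.* block indices, and the scale as 2 to the power of the number of
model.N.weight convolutions with N above scalemin=6.
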
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
new file mode 100644
index 00000000..bc9ceb2a
--- /dev/null
+++ b/modules/esrgan_model_arch.py
@@ -0,0 +1,463 @@
+# this file is adapted from https://github.com/victorca25/iNNfer
+
+import math
+import functools
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+####################
+# RRDBNet Generator
+####################
+
+class RRDBNet(nn.Module):
+ def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
+ act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
+ finalact=None, gaussian_noise=False, plus=False):
+ super(RRDBNet, self).__init__()
+ n_upscale = int(math.log(upscale, 2))
+ if upscale == 3:
+ n_upscale = 1
+
+ self.resrgan_scale = 0
+ if in_nc % 16 == 0:
+ self.resrgan_scale = 1
+ elif in_nc != 4 and in_nc % 4 == 0:
+ self.resrgan_scale = 2
+
+ fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+ rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
+ LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
+
+ if upsample_mode == 'upconv':
+ upsample_block = upconv_block
+ elif upsample_mode == 'pixelshuffle':
+ upsample_block = pixelshuffle_block
+ else:
+ raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+ if upscale == 3:
+ upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
+ else:
+ upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
+ HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
+ HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+
+ outact = act(finalact) if finalact else None
+
+ self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
+ *upsampler, HR_conv0, HR_conv1, outact)
+
+ def forward(self, x, outm=None):
+ if self.resrgan_scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ elif self.resrgan_scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ else:
+ feat = x
+
+ return self.model(feat)
+
+
+class RRDB(nn.Module):
+ """
+ Residual in Residual Dense Block
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+ """
+
+ def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False, gaussian_noise=False, plus=False):
+ super(RRDB, self).__init__()
+ # This is for backwards compatibility with existing models
+ if nr == 3:
+ self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ else:
+ RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
+ self.RDBs = nn.Sequential(*RDB_list)
+
+ def forward(self, x):
+ if hasattr(self, 'RDB1'):
+ out = self.RDB1(x)
+ out = self.RDB2(out)
+ out = self.RDB3(out)
+ else:
+ out = self.RDBs(x)
+ return out * 0.2 + x
+
+
+class ResidualDenseBlock_5C(nn.Module):
+ """
+ Residual Dense Block
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+ Modified options that can be used:
+ - "Partial Convolution based Padding" arXiv:1811.11718
+ - "Spectral normalization" arXiv:1802.05957
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
+ {Rakotonirina} and A. {Rasoanaivo}
+ """
+
+ def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False, gaussian_noise=False, plus=False):
+ super(ResidualDenseBlock_5C, self).__init__()
+
+ self.noise = GaussianNoise() if gaussian_noise else None
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
+
+ self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ if mode == 'CNA':
+ last_act = None
+ else:
+ last_act = act_type
+ self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+
+ def forward(self, x):
+ x1 = self.conv1(x)
+ x2 = self.conv2(torch.cat((x, x1), 1))
+ if self.conv1x1:
+ x2 = x2 + self.conv1x1(x)
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+ if self.conv1x1:
+ x4 = x4 + x2
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ if self.noise:
+ return self.noise(x5.mul(0.2) + x)
+ else:
+ return x5 * 0.2 + x
+
+
+####################
+# ESRGANplus
+####################
+
+class GaussianNoise(nn.Module):
+ def __init__(self, sigma=0.1, is_relative_detach=False):
+ super().__init__()
+ self.sigma = sigma
+ self.is_relative_detach = is_relative_detach
+ self.noise = torch.tensor(0, dtype=torch.float)
+
+ def forward(self, x):
+ if self.training and self.sigma != 0:
+ self.noise = self.noise.to(x.device)
+ scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+ sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+ x = x + sampled_noise
+ return x
+
+def conv1x1(in_planes, out_planes, stride=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+####################
+# SRVGGNetCompact
+####################
+
+class SRVGGNetCompact(nn.Module):
+ """A compact VGG-style network structure for super-resolution.
+ This class is copied from https://github.com/xinntao/Real-ESRGAN
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+ super(SRVGGNetCompact, self).__init__()
+ self.num_in_ch = num_in_ch
+ self.num_out_ch = num_out_ch
+ self.num_feat = num_feat
+ self.num_conv = num_conv
+ self.upscale = upscale
+ self.act_type = act_type
+
+ self.body = nn.ModuleList()
+ # the first conv
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+ # the first activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the body structure
+ for _ in range(num_conv):
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+ # activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the last conv
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+ # upsample
+ self.upsampler = nn.PixelShuffle(upscale)
+
+ def forward(self, x):
+ out = x
+ for i in range(0, len(self.body)):
+ out = self.body[i](out)
+
+ out = self.upsampler(out)
+ # add the nearest upsampled image, so that the network learns the residual
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+ out += base
+ return out
+
+
+####################
+# Upsampler
+####################
+
+class Upsample(nn.Module):
+ r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
+ The input data is assumed to be of the form
+ `minibatch x channels x [optional depth] x [optional height] x width`.
+ """
+
+ def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
+ super(Upsample, self).__init__()
+ if isinstance(scale_factor, tuple):
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
+ else:
+ self.scale_factor = float(scale_factor) if scale_factor else None
+ self.mode = mode
+ self.size = size
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+
+ def extra_repr(self):
+ if self.scale_factor is not None:
+ info = 'scale_factor=' + str(self.scale_factor)
+ else:
+ info = 'size=' + str(self.size)
+ info += ', mode=' + self.mode
+ return info
+
+
+def pixel_unshuffle(x, scale):
+ """ Pixel unshuffle.
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+ Returns:
+ Tensor: the pixel unshuffled feature.
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+ assert hh % scale == 0 and hw % scale == 0
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
+ """
+ Pixel shuffle layer
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+ Neural Network, CVPR17)
+ """
+ conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
+ pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+ n = norm(norm_type, out_nc) if norm_type else None
+ a = act(act_type) if act_type else None
+ return sequential(conv, pixel_shuffle, n, a)
+
+
+def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
+ """ Upconv layer """
+ upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
+ upsample = Upsample(scale_factor=upscale_factor, mode=mode)
+ conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
+ pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
+ return sequential(upsample, conv)
+
+
+
+####################
+# Basic blocks
+####################
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+ Args:
+ basic_block (nn.module): nn.module class for basic block. (block)
+ num_basic_block (int): number of blocks. (n_layers)
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
+
+def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
+ """ activation helper """
+ act_type = act_type.lower()
+ if act_type == 'relu':
+ layer = nn.ReLU(inplace)
+ elif act_type in ('leakyrelu', 'lrelu'):
+ layer = nn.LeakyReLU(neg_slope, inplace)
+ elif act_type == 'prelu':
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+ elif act_type == 'tanh': # [-1, 1] range output
+ layer = nn.Tanh()
+ elif act_type == 'sigmoid': # [0, 1] range output
+ layer = nn.Sigmoid()
+ else:
+ raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
+ return layer
+
+
+class Identity(nn.Module):
+    def __init__(self, *args):
+        super(Identity, self).__init__()
+
+    def forward(self, x, *args):
+ return x
+
+
+def norm(norm_type, nc):
+ """ Return a normalization layer """
+ norm_type = norm_type.lower()
+ if norm_type == 'batch':
+ layer = nn.BatchNorm2d(nc, affine=True)
+ elif norm_type == 'instance':
+ layer = nn.InstanceNorm2d(nc, affine=False)
+    elif norm_type == 'none':
+        layer = Identity()
+ else:
+ raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
+ return layer
+
+
+def pad(pad_type, padding):
+ """ padding layer helper """
+ pad_type = pad_type.lower()
+ if padding == 0:
+ return None
+ if pad_type == 'reflect':
+ layer = nn.ReflectionPad2d(padding)
+ elif pad_type == 'replicate':
+ layer = nn.ReplicationPad2d(padding)
+ elif pad_type == 'zero':
+ layer = nn.ZeroPad2d(padding)
+ else:
+ raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+ return layer
+
+
+def get_valid_padding(kernel_size, dilation):
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+ padding = (kernel_size - 1) // 2
+ return padding
+
+
+class ShortcutBlock(nn.Module):
+ """ Elementwise sum the output of a submodule to its input """
+ def __init__(self, submodule):
+ super(ShortcutBlock, self).__init__()
+ self.sub = submodule
+
+ def forward(self, x):
+ output = x + self.sub(x)
+ return output
+
+ def __repr__(self):
+ return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
+
+
+def sequential(*args):
+ """ Flatten Sequential. It unwraps nn.Sequential. """
+ if len(args) == 1:
+ if isinstance(args[0], OrderedDict):
+ raise NotImplementedError('sequential does not support OrderedDict input.')
+ return args[0] # No sequential is needed.
+ modules = []
+ for module in args:
+ if isinstance(module, nn.Sequential):
+ for submodule in module.children():
+ modules.append(submodule)
+ elif isinstance(module, nn.Module):
+ modules.append(module)
+ return nn.Sequential(*modules)
+
+
+def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False):
+ """ Conv layer with padding, normalization, activation """
+ assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
+ padding = get_valid_padding(kernel_size, dilation)
+ p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
+ padding = padding if pad_type == 'zero' else 0
+
+ if convtype=='PartialConv2D':
+ c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ elif convtype=='DeformConv2D':
+ c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ elif convtype=='Conv3D':
+ c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ else:
+ c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+
+ if spectral_norm:
+ c = nn.utils.spectral_norm(c)
+
+ a = act(act_type) if act_type else None
+ if 'CNA' in mode:
+ n = norm(norm_type, out_nc) if norm_type else None
+ return sequential(p, c, n, a)
+ elif mode == 'NAC':
+ if norm_type is None and act_type is not None:
+ a = act(act_type, inplace=False)
+ n = norm(norm_type, in_nc) if norm_type else None
+ return sequential(n, a, p, c)
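pixel_unshuffle trades spatial resolution for channels, which is how the x2/x1
Real-ESRGAN variants feed their inputs through a 4x trunk (see resrgan_scale in
RRDBNet.forward above). A quick shape check:

    import torch

    x = torch.randn(1, 3, 64, 64)
    y = pixel_unshuffle(x, scale=2)
    print(y.shape)  # torch.Size([1, 12, 32, 32]): channels * scale**2, spatial / scale
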
--
From bb57f30c2de46cfca5419ad01738a41705f96cc3 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Fri, 14 Oct 2022 10:56:41 +0200
Subject: init
---
modules/processing.py | 17 +++++-
modules/sd_hijack.py | 80 +++++++++++++++++++++++++-
modules/shared.py | 5 ++
modules/textual_inversion/dataset.py | 2 +-
modules/textual_inversion/textual_inversion.py | 35 +++++++----
modules/txt2img.py | 11 +++-
modules/ui.py | 59 ++++++++++++-------
7 files changed, 171 insertions(+), 38 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index d5172f00..9a033759 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -316,11 +316,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
-def process_images(p: StableDiffusionProcessing) -> Processed:
+def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0,
+                   aesthetic_imgs=None, aesthetic_slerp=False) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
+ aesthetic_lr = float(aesthetic_lr)
+ aesthetic_weight = float(aesthetic_weight)
+ aesthetic_steps = int(aesthetic_steps)
+
if type(p.prompt) == list:
- assert(len(p.prompt) > 0)
+ assert (len(p.prompt) > 0)
else:
assert p.prompt is not None
@@ -394,7 +399,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
#c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
- uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
+ if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
+ shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0)
+ uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt],
+ p.steps)
+ if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
+                shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight,
+                                                                      aesthetic_steps, aesthetic_imgs, aesthetic_slerp)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0:
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index c81722a0..6d5196fe 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -9,11 +9,14 @@ from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
-from modules.shared import opts, device, cmd_opts
+from modules.shared import opts, device, cmd_opts, aesthetic_embeddings
from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+from transformers import CLIPVisionModel, CLIPModel
+import torch.optim as optim
+import copy
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
@@ -109,13 +112,29 @@ class StableDiffusionModelHijack:
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
+def slerp(low, high, val):
+ low_norm = low/torch.norm(low, dim=1, keepdim=True)
+ high_norm = high/torch.norm(high, dim=1, keepdim=True)
+ omega = torch.acos((low_norm*high_norm).sum(1))
+ so = torch.sin(omega)
+ res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+ return res
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
+ self.clipModel = CLIPModel.from_pretrained(
+ self.wrapped.transformer.name_or_path
+ )
+ del self.clipModel.vision_model
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
+ # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval()
+ self.image_embs_name = None
+ self.image_embs = None
+ self.load_image_embs(None)
+
self.token_mults = {}
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0]
@@ -136,6 +155,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
+ def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None,
+ aesthetic_slerp=True):
+ self.slerp = aesthetic_slerp
+ self.aesthetic_lr = aesthetic_lr
+ self.aesthetic_weight = aesthetic_weight
+ self.aesthetic_steps = aesthetic_steps
+ self.load_image_embs(image_embs_name)
+
+ def load_image_embs(self, image_embs_name):
+ if image_embs_name is None or len(image_embs_name) == 0:
+ image_embs_name = None
+ if image_embs_name is not None and self.image_embs_name != image_embs_name:
+ self.image_embs_name = image_embs_name
+ self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device)
+ self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
+ self.image_embs.requires_grad_(False)
+
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_end = self.wrapped.tokenizer.eos_token_id
@@ -333,7 +369,47 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
+
+        if len(text[0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 \
+                and self.aesthetic_weight != 0 and self.image_embs_name is not None:
+ if not opts.use_old_emphasis_implementation:
+ remade_batch_tokens = [
+ [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
+ remade_batch_tokens]
+
+ tokens = torch.asarray(remade_batch_tokens).to(device)
+ with torch.enable_grad():
+ model = copy.deepcopy(self.clipModel).to(device)
+ model.requires_grad_(True)
+
+ # We optimize the model to maximize the similarity
+ optimizer = optim.Adam(
+ model.text_model.parameters(), lr=self.aesthetic_lr
+ )
+
+ for i in range(self.aesthetic_steps):
+ text_embs = model.get_text_features(input_ids=tokens)
+ text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
+ sim = text_embs @ self.image_embs.T
+ loss = -sim
+ optimizer.zero_grad()
+ loss.mean().backward()
+ optimizer.step()
+
+ zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+ if opts.CLIP_stop_at_last_layers > 1:
+ zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
+ zn = model.text_model.final_layer_norm(zn)
+ else:
+ zn = zn.last_hidden_state
+ model.cpu()
+ del model
+
+ if self.slerp:
+ z = slerp(z, zn, self.aesthetic_weight)
+ else:
+ z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
+
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
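The slerp helper added in this hunk interpolates along the great circle between
the prompt conditioning z and the aesthetically shifted zn. A small sanity check
with random row vectors (assuming the slerp defined above):

    import torch

    low, high = torch.randn(4, 768), torch.randn(4, 768)
    assert torch.allclose(slerp(low, high, 0.0), low, atol=1e-5)
    assert torch.allclose(slerp(low, high, 1.0), high, atol=1e-5)

Note that, unlike the guarded slerp in processing.py, this version divides by
sin(omega) with no near-parallel fallback, so almost-identical inputs would
produce NaNs.
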
diff --git a/modules/shared.py b/modules/shared.py
index 5901e605..cf13a10d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -30,6 +30,8 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(script_path, 'aesthetic_embeddings'),
+                    help="aesthetic_embeddings directory (default: aesthetic_embeddings)")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
@@ -90,6 +92,9 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
+aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
+ os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
+
def reload_hypernetworks():
global hypernetworks
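The comprehension above builds a name-to-path map keyed by file stem; for
instance, with a hypothetical aesthetic_embeddings/laion_7plus.pt on disk:

    >>> aesthetic_embeddings
    {'laion_7plus': '/.../aesthetic_embeddings/laion_7plus.pt'}

These names populate the "Imgs embedding" dropdown added to the UI below.
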
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 67e90afe..59b2b021 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -48,7 +48,7 @@ class PersonalizedBase(Dataset):
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
try:
- image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC)
except Exception:
continue
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fa0e33a2..b12a8e6d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -172,7 +172,15 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
+def batched(dataset, total, n=1):
+ for ndx in range(0, total, n):
+ yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
+
+
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps,
+ create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding,
+ preview_image_prompt, batch_size=1,
+ gradient_accumulation=1):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -204,7 +212,11 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
+ height=training_height,
+ repeats=shared.opts.training_image_repeats_per_epoch,
+ placeholder_token=embedding_name, model=shared.sd_model,
+ device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -223,7 +235,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
- pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+ pbar = tqdm.tqdm(enumerate(batched(ds, steps - ititial_step, batch_size)), total=steps - ititial_step)
for i, entry in pbar:
embedding.step = i + ititial_step
@@ -235,17 +247,20 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
break
with torch.autocast("cuda"):
- c = cond_model([entry.cond_text])
+ c = cond_model([e.cond_text for e in entry])
+
+ x = torch.stack([e.latent for e in entry]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
- x = entry.latent.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
- optimizer.zero_grad()
loss.backward()
- optimizer.step()
+ if ((i + 1) % gradient_accumulation == 0) or (i + 1 == steps - ititial_step):
+ optimizer.step()
+ optimizer.zero_grad()
+
epoch_num = embedding.step // len(ds)
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
@@ -259,7 +274,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
- preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
+ preview_text = entry[0].cond_text if preview_image_prompt == "" else preview_image_prompt
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
@@ -305,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entry[-1].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
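Stripped of the textual-inversion specifics, the loop change above is plain
gradient accumulation: gradients from several batches are summed in .grad and
the optimizer steps once per accumulation window. A generic sketch (model,
loss_fn and loader are placeholders, not names from this codebase):

    accum = 4  # gradient_accumulation
    optimizer.zero_grad()
    for i, batch in enumerate(loader):
        loss = loss_fn(model(batch))
        loss.backward()              # gradients accumulate across iterations
        if (i + 1) % accum == 0:
            optimizer.step()         # one update per accum batches
            optimizer.zero_grad()

Note that the hunk does not scale the loss by 1/gradient_accumulation, so the
effective gradient magnitude grows with the accumulation factor.
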
diff --git a/modules/txt2img.py b/modules/txt2img.py
index e985242b..78342024 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -6,7 +6,14 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int,
+ restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int,
+ subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool,
+ height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float,
+ aesthetic_lr=0,
+ aesthetic_weight=0, aesthetic_steps=0,
+ aesthetic_imgs=None,
+ aesthetic_slerp=False, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -40,7 +47,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
- processed = process_images(p)
+ processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp)
shared.total_tqdm.clear()
diff --git a/modules/ui.py b/modules/ui.py
index 220fb80b..d961d126 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -24,7 +24,8 @@ import gradio.routes
from modules import sd_hijack
from modules.paths import script_path
-from modules.shared import opts, cmd_opts
+from modules.shared import opts, cmd_opts, aesthetic_embeddings
+
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
import modules.shared as shared
@@ -534,6 +535,14 @@ def create_ui(wrap_gradio_gpu_call):
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ with gr.Group():
+ aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+ aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7)
+ aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50)
+
+ aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
+ aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
+
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
tiling = gr.Checkbox(label='Tiling', value=False)
@@ -586,25 +595,30 @@ def create_ui(wrap_gradio_gpu_call):
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
_js="submit",
inputs=[
- txt2img_prompt,
- txt2img_negative_prompt,
- txt2img_prompt_style,
- txt2img_prompt_style2,
- steps,
- sampler_index,
- restore_faces,
- tiling,
- batch_count,
- batch_size,
- cfg_scale,
- seed,
- subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
- height,
- width,
- enable_hr,
- scale_latent,
- denoising_strength,
- ] + custom_inputs,
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ txt2img_prompt_style,
+ txt2img_prompt_style2,
+ steps,
+ sampler_index,
+ restore_faces,
+ tiling,
+ batch_count,
+ batch_size,
+ cfg_scale,
+ seed,
+ subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
+ height,
+ width,
+ enable_hr,
+ scale_latent,
+ denoising_strength,
+ aesthetic_lr,
+ aesthetic_weight,
+ aesthetic_steps,
+ aesthetic_imgs,
+ aesthetic_slerp
+ ] + custom_inputs,
outputs=[
txt2img_gallery,
generation_info,
@@ -1097,6 +1111,9 @@ def create_ui(wrap_gradio_gpu_call):
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ batch_size = gr.Slider(minimum=1, maximum=64, step=1, label="Batch Size", value=4)
+ gradient_accumulation = gr.Slider(minimum=1, maximum=256, step=1, label="Gradient accumulation",
+ value=1)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
@@ -1180,6 +1197,8 @@ def create_ui(wrap_gradio_gpu_call):
template_file,
save_image_with_stored_embedding,
preview_image_prompt,
+ batch_size,
+ gradient_accumulation
],
outputs=[
ti_output,
--
From 37d7ffb415cd8c69b3c0bb5f61844dde0b169f78 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sat, 15 Oct 2022 15:59:37 +0200
Subject: fix token length handling, add embedding generator, and add new
 features to edit the embedding before generation using text
---
modules/aesthetic_clip.py | 78 ++++++++++++++++++++++++
modules/processing.py | 148 +++++++++++++++++++++++++++++++---------------
modules/sd_hijack.py | 111 ++++++++++++++++++++++------------
modules/shared.py | 4 ++
modules/txt2img.py | 10 +++-
modules/ui.py | 47 ++++++++++++---
6 files changed, 302 insertions(+), 96 deletions(-)
create mode 100644 modules/aesthetic_clip.py
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
new file mode 100644
index 00000000..f15cfd47
--- /dev/null
+++ b/modules/aesthetic_clip.py
@@ -0,0 +1,78 @@
+import itertools
+import os
+from pathlib import Path
+import html
+import gc
+
+import gradio as gr
+import torch
+from PIL import Image
+from modules import shared
+from modules.shared import device, aesthetic_embeddings
+from transformers import CLIPModel, CLIPProcessor
+
+from tqdm.auto import tqdm
+
+
+def get_all_images_in_folder(folder):
+ return [os.path.join(folder, f) for f in os.listdir(folder) if
+ os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
+
+
+def check_is_valid_image_file(filename):
+ return filename.lower().endswith(('.png', '.jpg', '.jpeg'))
+
+
+def batched(dataset, total, n=1):
+ for ndx in range(0, total, n):
+ yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
+
+
+def iter_to_batched(iterable, n=1):
+ it = iter(iterable)
+ while True:
+ chunk = tuple(itertools.islice(it, n))
+ if not chunk:
+ return
+ yield chunk
+
+
+def generate_imgs_embd(name, folder, batch_size):
+ # clipModel = CLIPModel.from_pretrained(
+ # shared.sd_model.cond_stage_model.clipModel.name_or_path
+ # )
+ model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device)
+ processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path)
+
+ with torch.no_grad():
+ embs = []
+ for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
+ desc=f"Generating embeddings for {name}"):
+ if shared.state.interrupted:
+ break
+ inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
+ outputs = model.get_image_features(**inputs).cpu()
+ embs.append(torch.clone(outputs))
+ inputs.to("cpu")
+ del inputs, outputs
+
+ embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
+
+ # The generated embedding will be located here
+ path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
+ torch.save(embs, path)
+
+ model = model.cpu()
+ del model
+ del processor
+ del embs
+ gc.collect()
+ torch.cuda.empty_cache()
+ res = f"""
+ Done generating embedding for {name}!
+    Aesthetic embedding saved to {html.escape(path)}
+ """
+ shared.update_aesthetic_embeddings()
+ return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding",
+ value=sorted(aesthetic_embeddings.keys())[0] if len(
+ aesthetic_embeddings) > 0 else None), res, ""
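iter_to_batched is a generic chunker over any iterable; for example:

    list(iter_to_batched(range(5), 2))
    # [(0, 1), (2, 3), (4,)]

generate_imgs_embd then averages the per-batch CLIP image features into a single
embedding tensor, which is what set_aesthetic_params later loads by name.
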
diff --git a/modules/processing.py b/modules/processing.py
index 9a033759..ab68d63a 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -20,7 +20,6 @@ import modules.images as images
import modules.styles
import logging
-
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
@@ -52,8 +51,13 @@ def get_correct_sampler(p):
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img
+
class StableDiffusionProcessing:
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1,
+ subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True,
+ sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512,
+ restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False,
+ extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -104,7 +108,8 @@ class StableDiffusionProcessing:
class Processed:
- def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None,
+ all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
self.images = images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
@@ -141,7 +146,8 @@ class Processed:
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
- self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
+ self.subseed = int(
+ self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed]
@@ -181,39 +187,43 @@ class Processed:
return json.dumps(obj)
- def infotext(self, p: StableDiffusionProcessing, index):
- return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
+ def infotext(self, p: StableDiffusionProcessing, index):
+ return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[],
+ position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
- low_norm = low/torch.norm(low, dim=1, keepdim=True)
- high_norm = high/torch.norm(high, dim=1, keepdim=True)
- dot = (low_norm*high_norm).sum(1)
+ low_norm = low / torch.norm(low, dim=1, keepdim=True)
+ high_norm = high / torch.norm(high, dim=1, keepdim=True)
+ dot = (low_norm * high_norm).sum(1)
if dot.mean() > 0.9995:
return low * val + high * (1 - val)
omega = torch.acos(dot)
so = torch.sin(omega)
- res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+ res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
-def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
+def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0,
+ p=None):
xs = []
# if we have multiple seeds, this means we are working with batch size>1; this then
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
+ if p is not None and p.sampler is not None and (
+ len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
for i, seed in enumerate(seeds):
- noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
+ noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (
+ shape[0], seed_resize_from_h // 8, seed_resize_from_w // 8)
subnoise = None
if subseeds is not None:
@@ -241,7 +251,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
dx = max(-dx, 0)
dy = max(-dy, 0)
- x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
+ x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
noise = x
if sampler_noises is not None:
@@ -293,14 +303,20 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
- "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
- "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
+ "Model hash": getattr(p, 'sd_model_hash',
+ None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
+ "Model": (
+ None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(
+ ',', '').replace(':', '')),
+ "Hypernet": (
+ None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(
+ ':', '')),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
- "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
+ "Seed resize from": (
+ None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip,
@@ -309,7 +325,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params.update(p.extra_generation_params)
- generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+ generation_params_text = ", ".join(
+ [k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
@@ -317,7 +334,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0,
- aesthetic_imgs=None,aesthetic_slerp=False) -> Processed:
+ aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="",
+ aesthetic_slerp_angle=0.15,
+ aesthetic_text_negative=False) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
aesthetic_lr = float(aesthetic_lr)
@@ -385,7 +404,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
for n in range(p.n_iter):
if state.skipped:
state.skipped = False
-
+
if state.interrupted:
break
@@ -396,16 +415,19 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
if (len(prompts) == 0):
break
- #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
- #c = p.sd_model.get_learned_conditioning(prompts)
+ # uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
+ # c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
- shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0)
+ shared.sd_model.cond_stage_model.set_aesthetic_params()
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt],
p.steps)
if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight,
- aesthetic_steps, aesthetic_imgs,aesthetic_slerp)
+ aesthetic_steps, aesthetic_imgs,
+ aesthetic_slerp, aesthetic_imgs_text,
+ aesthetic_slerp_angle,
+ aesthetic_text_negative)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0:
@@ -413,13 +435,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
comments[comment] = 1
if p.n_iter > 1:
- shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+ shared.state.job = f"Batch {n + 1} out of {p.n_iter}"
with devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds,
+ subseed_strength=p.subseed_strength)
if state.interrupted or state.skipped:
-
# if we are interrupted, sample returns just noise
# use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent
@@ -445,7 +467,9 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
- images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i],
+ opts.samples_format, info=infotext(n, i), p=p,
+ suffix="-before-face-restoration")
devices.torch_gc()
@@ -456,7 +480,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+ images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format,
+ info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
if p.overlay_images is not None and i < len(p.overlay_images):
@@ -474,7 +499,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
image = image.convert('RGB')
if opts.samples_save and not p.do_not_save_samples:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
+ images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format,
+ info=infotext(n, i), p=p)
text = infotext(n, i)
infotexts.append(text)
@@ -482,7 +508,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
image.info["parameters"] = text
output_images.append(image)
- del x_samples_ddim
+ del x_samples_ddim
devices.torch_gc()
@@ -504,10 +530,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format,
+ info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+ return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]),
+ subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds,
+ index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
@@ -543,25 +572,34 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds,
+ subseeds=subseeds, subseed_strength=self.subseed_strength,
+ seed_resize_from_h=self.seed_resize_from_h,
+ seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
return samples
- x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds,
+ subseeds=subseeds, subseed_strength=self.subseed_strength,
+ seed_resize_from_h=self.seed_resize_from_h,
+ seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
- samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
+ samples = samples[:, :, truncate_y // 2:samples.shape[2] - truncate_y // 2,
+ truncate_x // 2:samples.shape[3] - truncate_x // 2]
if self.scale_latent:
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f),
+ mode="bilinear")
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
- decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
+ decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width),
+ mode="bilinear")
else:
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -585,13 +623,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
- noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds,
+ subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h,
+ seed_resize_from_w=self.seed_resize_from_w, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning,
+ steps=self.steps)
return samples
@@ -599,7 +640,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
+ def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4,
+ inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0,
+ **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
@@ -607,7 +650,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.denoising_strength: float = denoising_strength
self.init_latent = None
self.image_mask = mask
- #self.image_unblurred_mask = None
+ # self.image_unblurred_mask = None
self.latent_mask = None
self.mask_for_overlay = None
self.mask_blur = mask_blur
@@ -619,7 +662,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.nmask = None
def init(self, all_prompts, all_seeds, all_subseeds):
- self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index,
+ self.sd_model)
crop_region = None
if self.image_mask is not None:
@@ -628,7 +672,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_mask_invert:
self.image_mask = ImageOps.invert(self.image_mask)
- #self.image_unblurred_mask = self.image_mask
+ # self.image_unblurred_mask = self.image_mask
if self.mask_blur > 0:
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
@@ -642,7 +686,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask = mask.crop(crop_region)
self.image_mask = images.resize_image(2, mask, self.width, self.height)
- self.paste_to = (x1, y1, x2-x1, y2-y1)
+ self.paste_to = (x1, y1, x2 - x1, y2 - y1)
else:
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
np_mask = np.array(self.image_mask)
@@ -665,7 +709,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.image_mask is not None:
image_masked = Image.new('RGBa', (image.width, image.height))
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+ image_masked.paste(image.convert("RGBA").convert("RGBa"),
+ mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
self.overlay_images.append(image_masked.convert('RGBA'))
@@ -714,12 +759,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
# this needs to be fixed to be done in sample() using actual seeds for batches
if self.inpainting_fill == 2:
- self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
+ self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:],
+ all_seeds[
+ 0:self.init_latent.shape[
+ 0]]) * self.nmask
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds,
+ subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h,
+ seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 6d5196fe..192883b2 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -14,7 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
-from transformers import CLIPVisionModel, CLIPModel
+from tqdm import trange
+from transformers import CLIPVisionModel, CLIPModel, CLIPTokenizer
import torch.optim as optim
import copy
@@ -22,21 +23,25 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
def apply_optimizations():
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
+ if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (
+ 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
+ elif not cmd_opts.disable_opt_split_attention and (
+ cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
if not invokeAI_mps_available and shared.device.type == 'mps':
- print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
+ print(
+ "The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
else:
@@ -112,14 +117,16 @@ class StableDiffusionModelHijack:
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
+
def slerp(low, high, val):
- low_norm = low/torch.norm(low, dim=1, keepdim=True)
- high_norm = high/torch.norm(high, dim=1, keepdim=True)
- omega = torch.acos((low_norm*high_norm).sum(1))
+ low_norm = low / torch.norm(low, dim=1, keepdim=True)
+ high_norm = high / torch.norm(high, dim=1, keepdim=True)
+ omega = torch.acos((low_norm * high_norm).sum(1))
so = torch.sin(omega)
- res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+ res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
+
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
@@ -128,6 +135,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.wrapped.transformer.name_or_path
)
del self.clipModel.vision_model
+ self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path)
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
# self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval()
@@ -139,7 +147,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0]
- tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+ tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if
+ '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
@@ -155,8 +164,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
- def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None,
- aesthetic_slerp=True):
+ def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
+ aesthetic_slerp=True, aesthetic_imgs_text="",
+ aesthetic_slerp_angle=0.15,
+ aesthetic_text_negative=False):
+ self.aesthetic_imgs_text = aesthetic_imgs_text
+ self.aesthetic_slerp_angle = aesthetic_slerp_angle
+ self.aesthetic_text_negative = aesthetic_text_negative
self.slerp = aesthetic_slerp
self.aesthetic_lr = aesthetic_lr
self.aesthetic_weight = aesthetic_weight
@@ -180,7 +194,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
else:
parsed = [[line, 1.0]]
- tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
+ tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)[
+ "input_ids"]
fixes = []
remade_tokens = []
@@ -196,18 +211,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if token == self.comma_token:
last_comma = len(remade_tokens)
- elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
+ elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens),
+ 1) % 75 == 0 and last_comma != -1 and len(
+ remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1
reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:]
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
-
+
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
-
+
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
@@ -248,7 +265,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
- remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms,
+ hijack_comments)
token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
@@ -259,7 +277,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
@@ -289,7 +306,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens,
+ i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
@@ -312,11 +330,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+ hijack_comments.append(
+ f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
@@ -326,23 +345,26 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
+
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(
+ text)
else:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(
+ text)
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
+ self.hijack.comments.append(
+ "Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
+
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
@@ -356,7 +378,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
-
+
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
@@ -378,19 +400,30 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_batch_tokens]
tokens = torch.asarray(remade_batch_tokens).to(device)
+
+ model = copy.deepcopy(self.clipModel).to(device)
+ model.requires_grad_(True)
+ if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
+ text_embs_2 = model.get_text_features(
+ **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
+ if self.aesthetic_text_negative:
+ text_embs_2 = self.image_embs - text_embs_2
+ text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
+ img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
+ else:
+ img_embs = self.image_embs
+
with torch.enable_grad():
- model = copy.deepcopy(self.clipModel).to(device)
- model.requires_grad_(True)
# We optimize the model to maximize the similarity
optimizer = optim.Adam(
model.text_model.parameters(), lr=self.aesthetic_lr
)
- for i in range(self.aesthetic_steps):
+ for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
text_embs = model.get_text_features(input_ids=tokens)
text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
- sim = text_embs @ self.image_embs.T
+ sim = text_embs @ img_embs.T
loss = -sim
optimizer.zero_grad()
loss.mean().backward()
@@ -405,6 +438,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
model.cpu()
del model
+ zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1)
if self.slerp:
z = slerp(z, zn, self.aesthetic_weight)
else:
@@ -413,15 +447,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
-
+
return z
-
-
+
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
+ remade_batch_tokens = [
+ [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
+ remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
-
+
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
@@ -461,8 +496,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
emb = embedding.vec
- emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
- tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+ emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+ tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
vecs.append(tensor)
diff --git a/modules/shared.py b/modules/shared.py
index cf13a10d..7cd608ca 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -95,6 +95,10 @@ loaded_hypernetwork = None
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
+def update_aesthetic_embeddings():
+ global aesthetic_embeddings
+ aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
+ os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
def reload_hypernetworks():
global hypernetworks
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 78342024..eedcdfe0 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -13,7 +13,11 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
aesthetic_lr=0,
aesthetic_weight=0, aesthetic_steps=0,
aesthetic_imgs=None,
- aesthetic_slerp=False, *args):
+ aesthetic_slerp=False,
+ aesthetic_imgs_text="",
+ aesthetic_slerp_angle=0.15,
+ aesthetic_text_negative=False,
+ *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -47,7 +51,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
- processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp)
+        processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text,
+ aesthetic_slerp_angle,
+ aesthetic_text_negative)
shared.total_tqdm.clear()
diff --git a/modules/ui.py b/modules/ui.py
index d961d126..e98e2113 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -41,6 +41,7 @@ from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui
+import modules.aesthetic_clip
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -449,7 +450,7 @@ def create_toprow(is_img2img):
with gr.Row():
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
with gr.Column(scale=1, elem_id="roll_col"):
- sh = gr.Button(elem_id="sh", visible=True)
+ sh = gr.Button(elem_id="sh", visible=True)
with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
@@ -536,9 +537,13 @@ def create_ui(wrap_gradio_gpu_call):
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Group():
- aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
- aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7)
- aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50)
+ aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001")
+ aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
+ aesthetic_steps = gr.Slider(minimum=0, maximum=256, step=1, label="Aesthetic steps", value=5)
+ with gr.Row():
+ aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
+ aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
+ aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
@@ -617,7 +622,10 @@ def create_ui(wrap_gradio_gpu_call):
aesthetic_weight,
aesthetic_steps,
aesthetic_imgs,
- aesthetic_slerp
+ aesthetic_slerp,
+ aesthetic_imgs_text,
+ aesthetic_slerp_angle,
+ aesthetic_text_negative
] + custom_inputs,
outputs=[
txt2img_gallery,
@@ -721,7 +729,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
- inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32)
+ inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=1024, step=4, value=32)
with gr.TabItem('Batch img2img', id='batch'):
hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
@@ -1071,6 +1079,17 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary')
+ with gr.Tab(label="Create images embedding"):
+ new_embedding_name_ae = gr.Textbox(label="Name")
+ process_src_ae = gr.Textbox(label='Source directory')
+ batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_embedding_ae = gr.Button(value="Create images embedding", variant='primary')
+
with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
@@ -1139,7 +1158,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=modules.textual_inversion.ui.create_embedding,
inputs=[
new_embedding_name,
- initialization_text,
+ process_src,
nvpt,
],
outputs=[
@@ -1149,6 +1168,20 @@ def create_ui(wrap_gradio_gpu_call):
]
)
+ create_embedding_ae.click(
+ fn=modules.aesthetic_clip.generate_imgs_embd,
+ inputs=[
+ new_embedding_name_ae,
+ process_src_ae,
+ batch_ae
+ ],
+ outputs=[
+ aesthetic_imgs,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
create_hypernetwork.click(
fn=modules.hypernetworks.ui.create_hypernetwork,
inputs=[
--
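For reference, the spherical interpolation (slerp) this patch adds to both processing.py and sd_hijack.py reads cleanly in isolation; a minimal sketch, with hypothetical tensor shapes standing in for the CLIP conditioning the webui actually passes:

    import torch

    def slerp(low, high, val):
        # normalize along the feature dimension so the dot product is a cosine
        low_norm = low / torch.norm(low, dim=1, keepdim=True)
        high_norm = high / torch.norm(high, dim=1, keepdim=True)
        omega = torch.acos((low_norm * high_norm).sum(1))  # per-row angle
        so = torch.sin(omega)
        res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low \
            + (torch.sin(val * omega) / so).unsqueeze(1) * high
        return res

    z = torch.randn(77, 768)     # prompt conditioning (hypothetical shape)
    zn = torch.randn(77, 768)    # aesthetic-optimized conditioning (hypothetical)
    blended = slerp(z, zn, 0.9)  # 0.9 mirrors the patch's aesthetic_weight default

In the patch, z is the regular prompt conditioning, zn is the conditioning rebuilt after the Adam steps against the image embedding, and aesthetic_weight picks a point on the arc between the two (the "Slerp interpolation" checkbox decides whether this path is taken).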
From 4387e4fe6479c08f7bc7e42924c3a1093e3a1872 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sat, 15 Oct 2022 18:39:29 +0200
Subject: Update modules/ui.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Víctor Gallego
---
modules/ui.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/modules/ui.py b/modules/ui.py
index d0696101..5bb961b2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -599,7 +599,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001")
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
- aesthetic_steps = gr.Slider(minimum=0, maximum=256, step=1, label="Aesthetic steps", value=5)
+ aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
+
with gr.Row():
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
--
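The aesthetic steps value this hunk re-caps bounds the per-generation Adam loop added in sd_hijack.py: the webui deep-copies the CLIP model, runs that many optimization steps, then rebuilds the conditioning. Stripped of plumbing, the loop is roughly the sketch below, where model, tokens, and img_embs stand in for the deep-copied transformers CLIPModel (with grads enabled, as the patch does), the tokenized prompt, and the loaded image embedding:

    import torch
    import torch.optim as optim

    def personalize(model, tokens, img_embs, lr=0.0001, steps=5):
        # nudge the text encoder so the prompt's features move toward
        # the frozen, normalized aesthetic image embedding
        optimizer = optim.Adam(model.text_model.parameters(), lr=lr)
        for _ in range(steps):
            text_embs = model.get_text_features(input_ids=tokens)
            text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
            loss = -(text_embs @ img_embs.T)  # maximize cosine similarity
            optimizer.zero_grad()
            loss.mean().backward()
            optimizer.step()
        return model

Because this loop runs inside every generation, lowering the slider's maximum from 256 to 50 keeps the worst-case per-image overhead bounded.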
From 9b7705e0573bddde26df4575c71f994d73a4d519 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sat, 15 Oct 2022 18:40:34 +0200
Subject: Update modules/aesthetic_clip.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Víctor Gallego
---
modules/aesthetic_clip.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
index f15cfd47..bcf2b073 100644
--- a/modules/aesthetic_clip.py
+++ b/modules/aesthetic_clip.py
@@ -70,7 +70,7 @@ def generate_imgs_embd(name, folder, batch_size):
torch.cuda.empty_cache()
res = f"""
Done generating embedding for {name}!
- Hypernetwork saved to {html.escape(path)}
+ Aesthetic embedding saved to {html.escape(path)}
"""
shared.update_aesthetic_embeddings()
return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding",
--
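Despite the old wording, the artifact behind this message is not a hypernetwork but a single mean CLIP image embedding stored as a bare tensor: generate_imgs_embd averages the per-image features and saves them with torch.save, and set_aesthetic_params later reloads, normalizes, and freezes the result. A sketch of the round trip, file name hypothetical:

    import torch

    # save side: embs is the concatenation of per-batch image features
    embs = torch.randn(500, 768)               # hypothetical CLIP image features
    mean_emb = embs.mean(dim=0, keepdim=True)  # collapses to a single [1, 768] vector
    torch.save(mean_emb, "aesthetic_embeddings/my_style.pt")

    # load side: normalize and freeze before using it as the optimization target
    image_embs = torch.load("aesthetic_embeddings/my_style.pt")
    image_embs /= image_embs.norm(dim=-1, keepdim=True)
    image_embs.requires_grad_(False)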
From 0d4f5db235357aeb4c7a8738179ba33aaf5a6b75 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sat, 15 Oct 2022 18:40:58 +0200
Subject: Update modules/ui.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Víctor Gallego
---
modules/ui.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/modules/ui.py b/modules/ui.py
index 5bb961b2..25eba548 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -597,7 +597,8 @@ def create_ui(wrap_gradio_gpu_call):
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Group():
- aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001")
+ aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001")
+
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
--
From ad9bc604a8fadcfebe72be37f66cec51e7e87fb5 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sat, 15 Oct 2022 18:41:18 +0200
Subject: Update modules/ui.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Víctor Gallego
---
modules/ui.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/modules/ui.py b/modules/ui.py
index 25eba548..3b28b69c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -607,7 +607,8 @@ def create_ui(wrap_gradio_gpu_call):
aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
- aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
+ aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Aesthetic imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
+
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
with gr.Row():
--
From 3f5c3b981e46c16bb10948d012575b25170efb3b Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sat, 15 Oct 2022 18:41:46 +0200
Subject: Update modules/ui.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Víctor Gallego
---
modules/ui.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/modules/ui.py b/modules/ui.py
index 3b28b69c..1f6fcdc9 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1190,7 +1190,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary')
- with gr.Tab(label="Create images embedding"):
+ with gr.Tab(label="Create aesthetic images embedding"):
+
new_embedding_name_ae = gr.Textbox(label="Name")
process_src_ae = gr.Textbox(label='Source directory')
batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
--
From 3d21684ee30ca5734126b8d08c05b3a0f513fe75 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sun, 16 Oct 2022 00:01:00 +0200
Subject: Add support for other image formats, fixed dropdown update
---
modules/aesthetic_clip.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
index bcf2b073..68264284 100644
--- a/modules/aesthetic_clip.py
+++ b/modules/aesthetic_clip.py
@@ -8,7 +8,7 @@ import gradio as gr
import torch
from PIL import Image
from modules import shared
-from modules.shared import device, aesthetic_embeddings
+from modules.shared import device
from transformers import CLIPModel, CLIPProcessor
from tqdm.auto import tqdm
@@ -20,7 +20,7 @@ def get_all_images_in_folder(folder):
def check_is_valid_image_file(filename):
- return filename.lower().endswith(('.png', '.jpg', '.jpeg'))
+ return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp"))
def batched(dataset, total, n=1):
@@ -73,6 +73,6 @@ def generate_imgs_embd(name, folder, batch_size):
Aesthetic embedding saved to {html.escape(path)}
"""
shared.update_aesthetic_embeddings()
- return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding",
- value=sorted(aesthetic_embeddings.keys())[0] if len(
- aesthetic_embeddings) > 0 else None), res, ""
+ return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
+ value=sorted(shared.aesthetic_embeddings.keys())[0] if len(
+ shared.aesthetic_embeddings) > 0 else None), res, ""
--
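The return-value change is the substantive fix here: in the Gradio 3.x versions the webui pinned at the time, returning a freshly constructed gr.Dropdown from an event handler does not refresh the component already mounted in the page, while gr.Dropdown.update(...) emits an update for it. A minimal sketch of the pattern, names hypothetical:

    import gradio as gr

    choices = ["first"]

    def refresh():
        choices.append(f"emb_{len(choices)}")
        # update the mounted dropdown in place instead of building a new component
        return gr.Dropdown.update(choices=sorted(choices), value=choices[-1])

    with gr.Blocks() as demo:
        dd = gr.Dropdown(choices=choices, label="Imgs embedding")
        gr.Button("Refresh").click(fn=refresh, outputs=[dd])

    demo.launch()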
From 9325c85f780c569d1823e422eaf51b2e497e0d3e Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sun, 16 Oct 2022 00:23:47 +0200
Subject: fixed dropdown update
---
modules/sd_hijack.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 192883b2..491312b4 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -9,7 +9,7 @@ from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
-from modules.shared import opts, device, cmd_opts, aesthetic_embeddings
+from modules.shared import opts, device, cmd_opts
from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
@@ -182,7 +182,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
image_embs_name = None
if image_embs_name is not None and self.image_embs_name != image_embs_name:
self.image_embs_name = image_embs_name
- self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device)
+ self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
self.image_embs.requires_grad_(False)
--
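This one-line change works because of a standard Python import gotcha: from modules.shared import aesthetic_embeddings binds the dict object once, at import time, so when update_aesthetic_embeddings() rebinds the module-level name, the importing module keeps pointing at the stale object. Reading shared.aesthetic_embeddings resolves the attribute at call time instead. A minimal reproduction, module names hypothetical:

    # shared.py
    registry = {"old": 1}

    def rebuild():
        global registry
        registry = {"new": 2}  # rebinds the name; does not mutate the old dict

    # consumer.py
    import shared
    from shared import registry

    shared.rebuild()
    print(registry)         # {'old': 1}  -- binding frozen at import time
    print(shared.registry)  # {'new': 2}  -- attribute lookup sees the rebind

Mutating the existing dict in place (clear() followed by update()) would have been the other way to keep both access styles consistent.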
From 763b893f319cee280b86e63025eb55e7c16b02e7 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Sun, 16 Oct 2022 10:03:09 +0800
Subject: images history sorting files by date
---
modules/images_history.py | 261 ++++++++++++++++++++++++++++++++++------------
1 file changed, 196 insertions(+), 65 deletions(-)
diff --git a/modules/images_history.py b/modules/images_history.py
index f5ef44fe..533cf51b 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -1,33 +1,74 @@
import os
import shutil
+import time
+import hashlib
+import gradio
+show_max_dates_num = 3
+system_bak_path = "webui_log_and_bak"
+def is_valid_date(date):
+ try:
+ time.strptime(date, "%Y%m%d")
+ return True
+ except:
+ return False
+def reduplicative_file_move(src, dst):
+ def same_name_file(basename, path):
+ name, ext = os.path.splitext(basename)
+ f_list = os.listdir(path)
+ max_num = 0
+ for f in f_list:
+ if len(f) <= len(basename):
+ continue
+ f_ext = f[-len(ext):] if len(ext) > 0 else ""
+ if f[:len(name)] == name and f_ext == ext:
+ if f[len(name)] == "(" and f[-len(ext)-1] == ")":
+ number = f[len(name)+1:-len(ext)-1]
+ if number.isdigit():
+ if int(number) > max_num:
+ max_num = int(number)
+ return f"{name}({max_num + 1}){ext}"
+ name = os.path.basename(src)
+ save_name = os.path.join(dst, name)
+ if not os.path.exists(save_name):
+ shutil.move(src, dst)
+ else:
+ name = same_name_file(name, dst)
+ shutil.move(src, os.path.join(dst, name))
-def traverse_all_files(output_dir, image_list, curr_dir=None):
- curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
+def traverse_all_files(curr_path, image_list, all_type=False):
try:
f_list = os.listdir(curr_path)
except:
- if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
- image_list.append(curr_dir)
+ if all_type or curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt":
+ image_list.append(curr_path)
return image_list
for file in f_list:
- file = file if curr_dir is None else os.path.join(curr_dir, file)
- file_path = os.path.join(curr_path, file)
- if file[-4:] == ".txt":
+ file = os.path.join(curr_path, file)
+ if (not all_type) and file[-4:] == ".txt":
pass
- elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
+ elif os.path.isfile(file) and file[-10:].rfind(".") > 0:
image_list.append(file)
else:
- image_list = traverse_all_files(output_dir, image_list, file)
+ image_list = traverse_all_files(file, image_list)
return image_list
-
-def get_recent_images(dir_name, page_index, step, image_index, tabname):
- page_index = int(page_index)
- f_list = os.listdir(dir_name)
+def get_recent_images(dir_name, page_index, step, image_index, tabname, date_from, date_to):
+ #print(f"turn_page {page_index}",date_from)
+ if date_from is None or date_from == "":
+ return None, 1, None, ""
image_list = []
- image_list = traverse_all_files(dir_name, image_list)
- image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
+ date_list = auto_sorting(dir_name)
+ page_index = int(page_index)
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ for date in date_list:
+ if date >= date_from and date <= date_to:
+ path = os.path.join(dir_name, date)
+ if date == today and not os.path.exists(path):
+ continue
+ image_list = traverse_all_files(path, image_list)
+
+ image_list = sorted(image_list, key=lambda file: -os.path.getctime(file))
num = 48 if tabname != "extras" else 12
max_page_index = len(image_list) // num + 1
page_index = max_page_index if page_index == -1 else page_index + step
@@ -38,40 +79,101 @@ def get_recent_images(dir_name, page_index, step, image_index, tabname):
image_index = int(image_index)
if image_index < 0 or image_index > len(image_list) - 1:
current_file = None
- hidden = None
else:
- current_file = image_list[int(image_index)]
- hidden = os.path.join(dir_name, current_file)
- return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
+ current_file = image_list[image_index]
+ return image_list, page_index, image_list, ""
+def auto_sorting(dir_name):
+ #print(f"auto sorting")
+ bak_path = os.path.join(dir_name, system_bak_path)
+ if not os.path.exists(bak_path):
+ os.mkdir(bak_path)
+ log_file = None
+ files_list = []
+ f_list = os.listdir(dir_name)
+ for file in f_list:
+ if file == system_bak_path:
+ continue
+ file_path = os.path.join(dir_name, file)
+ if not is_valid_date(file):
+ if file[-10:].rfind(".") > 0:
+ files_list.append(file_path)
+ else:
+ files_list = traverse_all_files(file_path, files_list, all_type=True)
+
+ for file in files_list:
+ date_str = time.strftime("%Y%m%d",time.localtime(os.path.getctime(file)))
+ file_path = os.path.dirname(file)
+ hash_path = hashlib.md5(file_path.encode()).hexdigest()
+ path = os.path.join(dir_name, date_str, hash_path)
+ if not os.path.exists(path):
+ os.makedirs(path)
+ if log_file is None:
+ log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a")
+ log_file.write(f"{hash_path},{file_path}\n")
+ reduplicative_file_move(file, path)
+
+ date_list = []
+ f_list = os.listdir(dir_name)
+ for f in f_list:
+ if is_valid_date(f):
+ date_list.append(f)
+ elif f == system_bak_path:
+ continue
+ else:
+ reduplicative_file_move(os.path.join(dir_name, f), bak_path)
+
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ if today not in date_list:
+ date_list.append(today)
+ return sorted(date_list)
-def first_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, 1, 0, image_index, tabname)
-def end_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, -1, 0, image_index, tabname)
+def archive_images(dir_name):
+ date_list = auto_sorting(dir_name)
+ date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0]
+ return (
+ gradio.update(visible=False),
+ gradio.update(visible=True),
+ gradio.Dropdown.update(choices=date_list, value=date_list[-1]),
+ gradio.Dropdown.update(choices=date_list, value=date_from)
+ )
+def date_to_change(dir_name, page_index, image_index, tabname, date_from, date_to):
+ #print("date_to", date_to)
+ date_list = auto_sorting(dir_name)
+ date_from_list = [date for date in date_list if date <= date_to]
+ date_from = date_from_list[0] if len(date_from_list) < show_max_dates_num else date_from_list[-show_max_dates_num]
+    image_list, page_index, image_list, _ = get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to)
+ return image_list, page_index, image_list, _, gradio.Dropdown.update(choices=date_from_list, value=date_from)
-def prev_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, page_index, -1, image_index, tabname)
+def first_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
+ return get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to)
-def next_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, page_index, 1, image_index, tabname)
+def end_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
+ return get_recent_images(dir_name, -1, 0, image_index, tabname, date_from, date_to)
-def page_index_change(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, page_index, 0, image_index, tabname)
+def prev_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
+ return get_recent_images(dir_name, page_index, -1, image_index, tabname, date_from, date_to)
-def show_image_info(num, image_path, filenames):
- # print(f"select image {num}")
- file = filenames[int(num)]
- return file, num, os.path.join(image_path, file)
+def next_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
+ return get_recent_images(dir_name, page_index, 1, image_index, tabname, date_from, date_to)
+
+
+def page_index_change(dir_name, page_index, image_index, tabname, date_from, date_to):
+ return get_recent_images(dir_name, page_index, 0, image_index, tabname, date_from, date_to)
-def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
+def show_image_info(tabname_box, num, filenames):
+ # #print(f"select image {num}")
+ file = filenames[int(num)]
+ return file, num, file
+
+def delete_image(delete_num, tabname, name, page_index, filenames, image_index):
if name == "":
return filenames, delete_num
else:
@@ -81,21 +183,19 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima
new_file_list = []
for name in filenames:
if i >= index and i < index + delete_num:
- path = os.path.join(dir_name, name)
- if os.path.exists(path):
- print(f"Delete file {path}")
- os.remove(path)
- txt_file = os.path.splitext(path)[0] + ".txt"
+ if os.path.exists(name):
+ #print(f"Delete file {name}")
+ os.remove(name)
+ txt_file = os.path.splitext(name)[0] + ".txt"
if os.path.exists(txt_file):
os.remove(txt_file)
else:
- print(f"Not exists file {path}")
+                    pass  #print(f"Not exists file {name}")
else:
new_file_list.append(name)
i += 1
return new_file_list, 1
-
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
if tabname == "txt2img":
dir_name = opts.outdir_txt2img_samples
@@ -107,16 +207,32 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
dir_name = d[0]
for p in d[1:]:
dir_name = os.path.join(dir_name, p)
- with gr.Row():
- renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
- first_page = gr.Button('First Page')
- prev_page = gr.Button('Prev Page')
- page_index = gr.Number(value=1, label="Page Index")
- next_page = gr.Button('Next Page')
- end_page = gr.Button('End Page')
- with gr.Row(elem_id=tabname + "_images_history"):
+
+ f_list = os.listdir(dir_name)
+ sorted_flag = os.path.exists(os.path.join(dir_name, system_bak_path)) or len(f_list) == 0
+ date_list, date_from, date_to = None, None, None
+ if sorted_flag:
+ #print(sorted_flag)
+ date_list = auto_sorting(dir_name)
+ date_to = date_list[-1]
+ date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0]
+
+ with gr.Column(visible=sorted_flag) as page_panel:
with gr.Row():
+ renew_page = gr.Button('Refresh', elem_id=tabname + "_images_history_renew_page", interactive=sorted_flag)
+ first_page = gr.Button('First Page')
+ prev_page = gr.Button('Prev Page')
+ page_index = gr.Number(value=1, label="Page Index")
+ next_page = gr.Button('Next Page')
+ end_page = gr.Button('End Page')
+
+ with gr.Row(elem_id=tabname + "_images_history"):
with gr.Column(scale=2):
+ with gr.Row():
+ newest = gr.Button('Newest')
+ date_to = gr.Dropdown(choices=date_list, value=date_to, label="Date to")
+ date_from = gr.Dropdown(choices=date_list, value=date_from, label="Date from")
+
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
with gr.Row():
delete_num = gr.Number(value=1, interactive=True, label="Number of images to delete consecutively")
@@ -128,22 +244,31 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
with gr.Row():
with gr.Column():
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
- img_file_name = gr.Textbox(label="File Name", interactive=False)
- with gr.Row():
+ img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
# hidden items
+ with gr.Row(visible=False):
+ img_path = gr.Textbox(dir_name)
+ tabname_box = gr.Textbox(tabname)
+ image_index = gr.Textbox(value=-1)
+ set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
+ filenames = gr.State()
+ hidden = gr.Image(type="pil")
+ info1 = gr.Textbox()
+ info2 = gr.Textbox()
+ with gr.Column(visible=not sorted_flag) as init_warning:
+ with gr.Row():
+ gr.Textbox("The system needs to archive the files according to the date. This requires changing the directory structure of the files",
+ label="Waring",
+ css="")
+ with gr.Row():
+ sorted_button = gr.Button('Confirm')
- img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
- tabname_box = gr.Textbox(tabname, visible=False)
- image_index = gr.Textbox(value=-1, visible=False)
- set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
- filenames = gr.State()
- hidden = gr.Image(type="pil", visible=False)
- info1 = gr.Textbox(visible=False)
- info2 = gr.Textbox(visible=False)
-
+
+
+
# turn pages
- gallery_inputs = [img_path, page_index, image_index, tabname_box]
- gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
+ gallery_inputs = [img_path, page_index, image_index, tabname_box, date_from, date_to]
+ gallery_outputs = [history_gallery, page_index, filenames, img_file_name]
first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
@@ -154,15 +279,21 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
# other functions
- set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
+ set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, filenames], outputs=[img_file_name, image_index, hidden])
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
- delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
+ delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
-
+ date_to.change(date_to_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs + [date_from])
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
+ sorted_button.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from])
+ newest.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from])
+
+
+
+
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
with gr.Blocks(analytics_enabled=False) as images_history:
--
cgit v1.2.3
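The patch above replaces flat directory listing with per-day subfolders plus a date window (date_from, date_to) for paging. A minimal sketch of that idea in Python; names such as list_page and per_page are illustrative, not the module's exact code:

import os

def list_page(dir_name, date_list, date_from, date_to, page_index, per_page=48):
    # Gather files from each per-day folder inside the selected window.
    files = []
    for date in date_list:
        if date_from <= date <= date_to:
            path = os.path.join(dir_name, date)
            if not os.path.exists(path):
                continue  # today's folder may not have been created yet
            for name in os.listdir(path):
                files.append(os.path.join(path, name))
    files.sort(key=lambda f: -os.path.getctime(f))  # newest first
    max_page = len(files) // per_page + 1
    page_index = min(max(int(page_index), 1), max_page)  # clamp to valid range
    start = (page_index - 1) * per_page
    return files[start:start + per_page], page_index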
From 523140d7805c644700009b8a2483ff4eb4a22304 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sun, 16 Oct 2022 10:23:30 +0200
Subject: ui fix
---
modules/aesthetic_clip.py | 3 +--
modules/sd_hijack.py | 3 +--
modules/shared.py | 2 ++
modules/ui.py | 24 ++++++++++++++----------
4 files changed, 18 insertions(+), 14 deletions(-)
(limited to 'modules')
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
index 68264284..ccb35c73 100644
--- a/modules/aesthetic_clip.py
+++ b/modules/aesthetic_clip.py
@@ -74,5 +74,4 @@ def generate_imgs_embd(name, folder, batch_size):
"""
shared.update_aesthetic_embeddings()
return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
- value=sorted(shared.aesthetic_embeddings.keys())[0] if len(
- shared.aesthetic_embeddings) > 0 else None), res, ""
+ value="None"), res, ""
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 01fcb78f..2de2eed5 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -392,8 +392,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
- if len(text[
- 0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None:
+ if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None:
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [
[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
diff --git a/modules/shared.py b/modules/shared.py
index 3c5ffef1..e2c98b2d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -96,11 +96,13 @@ loaded_hypernetwork = None
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
+aesthetic_embeddings = aesthetic_embeddings | {"None": None}
def update_aesthetic_embeddings():
global aesthetic_embeddings
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
+ aesthetic_embeddings = aesthetic_embeddings | {"None": None}
def reload_hypernetworks():
global hypernetworks
diff --git a/modules/ui.py b/modules/ui.py
index 13ba3142..4069f0d2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -594,19 +594,23 @@ def create_ui(wrap_gradio_gpu_call):
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Group():
- aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001")
-
- aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
- aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
+ with gr.Accordion("Open for Clip Aesthetic!",open=False):
+ with gr.Row():
+ aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
+ aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
- with gr.Row():
- aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
- aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
- aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
+ with gr.Row():
+ aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001")
+ aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
+ aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()),
+ label="Aesthetic imgs embedding",
+ value="None")
- aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Aesthetic imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
+ with gr.Row():
+ aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
+ aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
+ aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
- aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
--
cgit v1.2.3
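The "None" default above works because the embeddings dict always carries a sentinel entry mapping the string "None" to None, so the dropdown never needs to special-case an empty embeddings folder. A sketch of the pattern (embs_dir is an illustrative name); note the `|` dict union requires Python 3.9+:

import os

def list_embeddings(embs_dir):
    embs = {f.replace(".pt", ""): os.path.join(embs_dir, f)
            for f in os.listdir(embs_dir) if f.endswith(".pt")}
    return embs | {"None": None}  # sentinel choice; dict union needs Python 3.9+

def resolve(embs, name):
    # The string "None" (and any unknown name) resolves to no embedding.
    return embs.get(name)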
From e4f8b5f00dd33b7547cc6b76fbed26bb83b37a64 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sun, 16 Oct 2022 10:28:21 +0200
Subject: ui fix
---
modules/sd_hijack.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 2de2eed5..5d0590af 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -178,7 +178,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.load_image_embs(image_embs_name)
def load_image_embs(self, image_embs_name):
- if image_embs_name is None or len(image_embs_name) == 0:
+ if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
image_embs_name = None
if image_embs_name is not None and self.image_embs_name != image_embs_name:
self.image_embs_name = image_embs_name
--
cgit v1.2.3
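Since the dropdown now hands back the literal string "None", any consumer has to normalize it alongside None and the empty string. A sketch of that guard, assuming a hypothetical helper name normalize_embs_name:

def normalize_embs_name(image_embs_name):
    # Treat None, "" and the dropdown's literal "None" as no selection.
    if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
        return None
    return image_embs_name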
From f62905fdf928b54aa76765e5cbde8d538d494e49 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Sun, 16 Oct 2022 21:22:38 +0800
Subject: images history speed up
---
modules/images_history.py | 250 ++++++++++++++++++++++++----------------------
1 file changed, 128 insertions(+), 122 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index 7fd75005..ae0b4e40 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -3,8 +3,10 @@ import shutil
import time
import hashlib
import gradio
-show_max_dates_num = 3
+
system_bak_path = "webui_log_and_bak"
+loads_files_num = 216
+num_of_imgs_per_page = 36
def is_valid_date(date):
try:
time.strptime(date, "%Y%m%d")
@@ -53,38 +55,7 @@ def traverse_all_files(curr_path, image_list, all_type=False):
image_list = traverse_all_files(file, image_list)
return image_list
-def get_recent_images(dir_name, page_index, step, image_index, tabname, date_from, date_to):
- #print(f"turn_page {page_index}",date_from)
- if date_from is None or date_from == "":
- return None, 1, None, ""
- image_list = []
- date_list = auto_sorting(dir_name)
- page_index = int(page_index)
- today = time.strftime("%Y%m%d",time.localtime(time.time()))
- for date in date_list:
- if date >= date_from and date <= date_to:
- path = os.path.join(dir_name, date)
- if date == today and not os.path.exists(path):
- continue
- image_list = traverse_all_files(path, image_list)
-
- image_list = sorted(image_list, key=lambda file: -os.path.getctime(file))
- num = 48 if tabname != "extras" else 12
- max_page_index = len(image_list) // num + 1
- page_index = max_page_index if page_index == -1 else page_index + step
- page_index = 1 if page_index < 1 else page_index
- page_index = max_page_index if page_index > max_page_index else page_index
- idx_frm = (page_index - 1) * num
- image_list = image_list[idx_frm:idx_frm + num]
- image_index = int(image_index)
- if image_index < 0 or image_index > len(image_list) - 1:
- current_file = None
- else:
- current_file = image_list[image_index]
- return image_list, page_index, image_list, ""
-
-def auto_sorting(dir_name):
- #print(f"auto sorting")
+def auto_sorting(dir_name):
bak_path = os.path.join(dir_name, system_bak_path)
if not os.path.exists(bak_path):
os.mkdir(bak_path)
@@ -126,102 +97,131 @@ def auto_sorting(dir_name):
today = time.strftime("%Y%m%d",time.localtime(time.time()))
if today not in date_list:
date_list.append(today)
- return sorted(date_list)
+ return sorted(date_list, reverse=True)
-def archive_images(dir_name):
+def archive_images(dir_name, date_to):
date_list = auto_sorting(dir_name)
- date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0]
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ date_to = today if date_to is None or date_to == "" else date_to
+ filenames = []
+ for date in date_list:
+ if date <= date_to:
+ path = os.path.join(dir_name, date)
+ if date == today and not os.path.exists(path):
+ continue
+ filenames = traverse_all_files(path, filenames)
+ if len(filenames) > loads_files_num:
+ break
+ filenames = sorted(filenames, key=lambda file: -os.path.getctime(file))
+ _, image_list, _, visible_num = get_recent_images(1, 0, filenames)
return (
gradio.update(visible=False),
gradio.update(visible=True),
- gradio.Dropdown.update(choices=date_list, value=date_list[-1]),
- gradio.Dropdown.update(choices=date_list, value=date_from)
+ gradio.Dropdown.update(choices=date_list, value=date_to),
+ date,
+ filenames,
+ 1,
+ image_list,
+ "",
+ visible_num
)
+def system_init(dir_name):
+ ret = [x for x in archive_images(dir_name, None)]
+ ret += [gradio.update(visible=False)]
+ return ret
+
+def newest_click(dir_name, date_to):
+ if date_to == "start":
+ return True, False, "start", None, None, 1, None, ""
+ else:
+ return archive_images(dir_name, time.strftime("%Y%m%d",time.localtime(time.time())))
-def date_to_change(dir_name, page_index, image_index, tabname, date_from, date_to):
- #print("date_to", date_to)
- date_list = auto_sorting(dir_name)
- date_from_list = [date for date in date_list if date <= date_to]
- date_from = date_from_list[0] if len(date_from_list) < show_max_dates_num else date_from_list[-show_max_dates_num]
- image_list, page_index, image_list, _ =get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to)
- return image_list, page_index, image_list, _, gradio.Dropdown.update(choices=date_from_list, value=date_from)
-
-def first_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
- return get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to)
-
-
-def end_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
- return get_recent_images(dir_name, -1, 0, image_index, tabname, date_from, date_to)
-
-
-def prev_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
- return get_recent_images(dir_name, page_index, -1, image_index, tabname, date_from, date_to)
-
-
-def next_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
- return get_recent_images(dir_name, page_index, 1, image_index, tabname, date_from, date_to)
-
-
-def page_index_change(dir_name, page_index, image_index, tabname, date_from, date_to):
- return get_recent_images(dir_name, page_index, 0, image_index, tabname, date_from, date_to)
-
-
-def show_image_info(tabname_box, num, filenames):
- # #print(f"select image {num}")
- file = filenames[int(num)]
- return file, num, file
-
-def delete_image(delete_num, tabname, name, page_index, filenames, image_index):
+def delete_image(delete_num, name, filenames, image_index, visible_num):
if name == "":
return filenames, delete_num
else:
delete_num = int(delete_num)
+ visible_num = int(visible_num)
+ image_index = int(image_index)
index = list(filenames).index(name)
i = 0
new_file_list = []
for name in filenames:
if i >= index and i < index + delete_num:
if os.path.exists(name):
- #print(f"Delete file {name}")
+ if visible_num == image_index:
+ new_file_list.append(name)
+ continue
+ print(f"Delete file {name}")
os.remove(name)
+ visible_num -= 1
txt_file = os.path.splitext(name)[0] + ".txt"
if os.path.exists(txt_file):
os.remove(txt_file)
else:
- #print(f"Not exists file {name}")
+ print(f"Not exists file {name}")
else:
new_file_list.append(name)
i += 1
- return new_file_list, 1
+ return new_file_list, 1, visible_num
+
+def get_recent_images(page_index, step, filenames):
+ page_index = int(page_index)
+ max_page_index = len(filenames) // num_of_imgs_per_page + 1
+ page_index = max_page_index if page_index == -1 else page_index + step
+ page_index = 1 if page_index < 1 else page_index
+ page_index = max_page_index if page_index > max_page_index else page_index
+ idx_frm = (page_index - 1) * num_of_imgs_per_page
+ image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page]
+ length = len(filenames)
+ visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page
+ visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num
+ return page_index, image_list, "", visible_num
+
+def first_page_click(page_index, filenames):
+ return get_recent_images(1, 0, filenames)
+
+def end_page_click(page_index, filenames):
+ return get_recent_images(-1, 0, filenames)
+
+def prev_page_click(page_index, filenames):
+ return get_recent_images(page_index, -1, filenames)
+
+def next_page_click(page_index, filenames):
+ return get_recent_images(page_index, 1, filenames)
+
+def page_index_change(page_index, filenames):
+ return get_recent_images(page_index, 0, filenames)
+
+def show_image_info(tabname_box, num, page_index, filenames):
+ file = filenames[int(num) + int((page_index - 1) * num_of_imgs_per_page)]
+ return file, num, file
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
- if opts.outdir_samples != "":
- dir_name = opts.outdir_samples
- elif tabname == "txt2img":
+ if tabname == "txt2img":
dir_name = opts.outdir_txt2img_samples
elif tabname == "img2img":
dir_name = opts.outdir_img2img_samples
elif tabname == "extras":
dir_name = opts.outdir_extras_samples
+ elif tabname == "saved":
+ dir_name = opts.outdir_save
+ if not os.path.exists(dir_name):
+ os.makedirs(dir_name)
d = dir_name.split("/")
- dir_name = "/" if dir_name.startswith("/") else d[0]
+ dir_name = d[0]
for p in d[1:]:
dir_name = os.path.join(dir_name, p)
f_list = os.listdir(dir_name)
sorted_flag = os.path.exists(os.path.join(dir_name, system_bak_path)) or len(f_list) == 0
date_list, date_from, date_to = None, None, None
- if sorted_flag:
- #print(sorted_flag)
- date_list = auto_sorting(dir_name)
- date_to = date_list[-1]
- date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0]
with gr.Column(visible=sorted_flag) as page_panel:
with gr.Row():
- renew_page = gr.Button('Refresh', elem_id=tabname + "_images_history_renew_page", interactive=sorted_flag)
+ #renew_page = gr.Button('Refresh')
first_page = gr.Button('First Page')
prev_page = gr.Button('Prev Page')
page_index = gr.Number(value=1, label="Page Index")
@@ -231,9 +231,9 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
with gr.Row(elem_id=tabname + "_images_history"):
with gr.Column(scale=2):
with gr.Row():
- newest = gr.Button('Newest')
- date_to = gr.Dropdown(choices=date_list, value=date_to, label="Date to")
- date_from = gr.Dropdown(choices=date_list, value=date_from, label="Date from")
+ newest = gr.Button('Refresh', elem_id=tabname + "_images_history_start")
+ date_from = gr.Textbox(label="Date from", interactive=False)
+ date_to = gr.Dropdown(value="start" if not sorted_flag else None, label="Date to")
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
with gr.Row():
@@ -247,66 +247,72 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
with gr.Column():
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
+
# hidden items
- with gr.Row(visible=False):
+ with gr.Row(visible=False):
+ visible_img_num = gr.Number()
img_path = gr.Textbox(dir_name)
tabname_box = gr.Textbox(tabname)
image_index = gr.Textbox(value=-1)
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
filenames = gr.State()
+ all_images_list = gr.State()
hidden = gr.Image(type="pil")
info1 = gr.Textbox()
info2 = gr.Textbox()
+
with gr.Column(visible=not sorted_flag) as init_warning:
with gr.Row():
- gr.Textbox("The system needs to archive the files according to the date. This requires changing the directory structure of the files",
- label="Waring",
- css="")
+ warning = gr.Textbox(
+ label="Waring",
+ value=f"The system needs to archive the files according to the date. This requires changing the directory structure of the files.If you have doubts about this operation, you can first back up the files in the '{dir_name}' directory"
+ )
+ warning.style(height=100, width=50)
with gr.Row():
sorted_button = gr.Button('Confirm')
-
-
+ change_date_output = [init_warning, page_panel, date_to, date_from, filenames, page_index, history_gallery, img_file_name, visible_img_num]
+ sorted_button.click(system_init, inputs=[img_path], outputs=change_date_output + [sorted_button])
+ newest.click(newest_click, inputs=[img_path, date_to], outputs=change_date_output)
+ date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output)
+ date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ newest.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+
+ delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num])
+ delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None)
+
# turn pages
- gallery_inputs = [img_path, page_index, image_index, tabname_box, date_from, date_to]
- gallery_outputs = [history_gallery, page_index, filenames, img_file_name]
+ gallery_inputs = [page_index, filenames]
+ gallery_outputs = [page_index, history_gallery, img_file_name, visible_img_num]
+
+ first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
+ next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
+ prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
+ end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
+ page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
- first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
+ first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
# other functions
- set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, filenames], outputs=[img_file_name, image_index, hidden])
- img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
- delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
+ set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden])
+ img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
- date_to.change(date_to_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs + [date_from])
- # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
+
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
- sorted_button.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from])
- newest.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from])
-
-
-
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
with gr.Blocks(analytics_enabled=False) as images_history:
with gr.Tabs() as tabs:
- with gr.Tab("txt2img history"):
- with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
- show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
- with gr.Tab("img2img history"):
- with gr.Blocks(analytics_enabled=False) as images_history_img2img:
- show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
- with gr.Tab("extras history"):
- with gr.Blocks(analytics_enabled=False) as images_history_img2img:
- show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
+ for tab in ["saved", "txt2img", "img2img", "extras"]:
+ with gr.Tab(tab):
+ with gr.Blocks(analytics_enabled=False) as images_history_img2img:
+ show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
return images_history
--
cgit v1.2.3
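The speed-up above comes from loading the filename list once into gr.State and turning pages purely in memory, so only slicing arithmetic runs per click. A sketch of that paging math, where page_index == -1 means "last page" (illustrative, not the module's exact code):

def page(filenames, page_index, step, per_page=36):
    page_index = int(page_index)
    max_page = len(filenames) // per_page + 1
    page_index = max_page if page_index == -1 else page_index + step
    page_index = min(max(page_index, 1), max_page)  # clamp to [1, max_page]
    start = (page_index - 1) * per_page
    images = filenames[start:start + per_page]
    visible = len(images) if images else per_page  # count actually shown on this page
    return page_index, images, visible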
From a4de699e3c235d83b5a957d08779cb41cb0781bc Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Sun, 16 Oct 2022 22:37:12 +0800
Subject: Images history speed up
---
modules/images_history.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index ae0b4e40..94bd16a8 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -153,6 +153,7 @@ def delete_image(delete_num, name, filenames, image_index, visible_num):
if os.path.exists(name):
if visible_num == image_index:
new_file_list.append(name)
+ i += 1
continue
print(f"Delete file {name}")
os.remove(name)
@@ -221,7 +222,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
with gr.Column(visible=sorted_flag) as page_panel:
with gr.Row():
- #renew_page = gr.Button('Refresh')
+ renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
first_page = gr.Button('First Page')
prev_page = gr.Button('Prev Page')
page_index = gr.Number(value=1, label="Page Index")
@@ -231,7 +232,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
with gr.Row(elem_id=tabname + "_images_history"):
with gr.Column(scale=2):
with gr.Row():
- newest = gr.Button('Refresh', elem_id=tabname + "_images_history_start")
+ newest = gr.Button('Reload', elem_id=tabname + "_images_history_start")
date_from = gr.Textbox(label="Date from", interactive=False)
date_to = gr.Dropdown(value="start" if not sorted_flag else None, label="Date to")
@@ -291,12 +292,14 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
+ renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
# other functions
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden])
--
cgit v1.2.3
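The delete_image fix in the first hunk (the added `i += 1` before `continue`) is the classic manual-counter hazard: a `continue` that skips the increment desynchronizes the index from the iteration. Iterating with enumerate() sidesteps the problem entirely; a sketch under that assumption:

def keep_outside_range(filenames, index, delete_num):
    kept = []
    for i, name in enumerate(filenames):  # enumerate keeps i correct even on skips
        if index <= i < index + delete_num:
            continue  # slated for deletion; no manual counter to forget
        kept.append(name)
    return kept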
From 9324cdaa3199d65c182858785dd1eca42b192b8e Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sun, 16 Oct 2022 17:53:56 +0200
Subject: ui fix, reorganization of the code
---
modules/aesthetic_clip.py | 154 +++++++++++++++++++++++++++++++++--
modules/img2img.py | 14 +++-
modules/processing.py | 29 ++-----
modules/sd_hijack.py | 102 ++---------------------
modules/sd_models.py | 5 +-
modules/shared.py | 14 +++-
modules/textual_inversion/dataset.py | 2 +-
modules/txt2img.py | 18 ++--
modules/ui.py | 52 +++++++-----
9 files changed, 233 insertions(+), 157 deletions(-)
(limited to 'modules')
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
index ccb35c73..34efa931 100644
--- a/modules/aesthetic_clip.py
+++ b/modules/aesthetic_clip.py
@@ -1,3 +1,4 @@
+import copy
import itertools
import os
from pathlib import Path
@@ -7,11 +8,12 @@ import gc
import gradio as gr
import torch
from PIL import Image
-from modules import shared
-from modules.shared import device
-from transformers import CLIPModel, CLIPProcessor
+from torch import optim
-from tqdm.auto import tqdm
+from modules import shared
+from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
+from tqdm.auto import tqdm, trange
+from modules.shared import opts, device
def get_all_images_in_folder(folder):
@@ -37,12 +39,39 @@ def iter_to_batched(iterable, n=1):
yield chunk
+def create_ui():
+ with gr.Group():
+ with gr.Accordion("Open for Clip Aesthetic!", open=False):
+ with gr.Row():
+ aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight",
+ value=0.9)
+ aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
+
+ with gr.Row():
+ aesthetic_lr = gr.Textbox(label='Aesthetic learning rate',
+ placeholder="Aesthetic learning rate", value="0.0001")
+ aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
+ aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
+ label="Aesthetic imgs embedding",
+ value="None")
+
+ with gr.Row():
+ aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
+ placeholder="This text is used to rotate the feature space of the imgs embs",
+ value="")
+ aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01,
+ value=0.1)
+ aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
+
+ return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
+
+
def generate_imgs_embd(name, folder, batch_size):
# clipModel = CLIPModel.from_pretrained(
# shared.sd_model.cond_stage_model.clipModel.name_or_path
# )
- model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device)
- processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path)
+ model = shared.clip_model.to(device)
+ processor = CLIPProcessor.from_pretrained(model.name_or_path)
with torch.no_grad():
embs = []
@@ -63,7 +92,6 @@ def generate_imgs_embd(name, folder, batch_size):
torch.save(embs, path)
model = model.cpu()
- del model
del processor
del embs
gc.collect()
@@ -74,4 +102,114 @@ def generate_imgs_embd(name, folder, batch_size):
"""
shared.update_aesthetic_embeddings()
return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
- value="None"), res, ""
+ value="None"), \
+ gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
+ label="Imgs embedding",
+ value="None"), res, ""
+
+
+def slerp(low, high, val):
+ low_norm = low / torch.norm(low, dim=1, keepdim=True)
+ high_norm = high / torch.norm(high, dim=1, keepdim=True)
+ omega = torch.acos((low_norm * high_norm).sum(1))
+ so = torch.sin(omega)
+ res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
+ return res
+
+
+class AestheticCLIP:
+ def __init__(self):
+ self.skip = False
+ self.aesthetic_steps = 0
+ self.aesthetic_weight = 0
+ self.aesthetic_lr = 0
+ self.slerp = False
+ self.aesthetic_text_negative = ""
+ self.aesthetic_slerp_angle = 0
+ self.aesthetic_imgs_text = ""
+
+ self.image_embs_name = None
+ self.image_embs = None
+ self.load_image_embs(None)
+
+ def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
+ aesthetic_slerp=True, aesthetic_imgs_text="",
+ aesthetic_slerp_angle=0.15,
+ aesthetic_text_negative=False):
+ self.aesthetic_imgs_text = aesthetic_imgs_text
+ self.aesthetic_slerp_angle = aesthetic_slerp_angle
+ self.aesthetic_text_negative = aesthetic_text_negative
+ self.slerp = aesthetic_slerp
+ self.aesthetic_lr = aesthetic_lr
+ self.aesthetic_weight = aesthetic_weight
+ self.aesthetic_steps = aesthetic_steps
+ self.load_image_embs(image_embs_name)
+
+ def set_skip(self, skip):
+ self.skip = skip
+
+ def load_image_embs(self, image_embs_name):
+ if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
+ image_embs_name = None
+ self.image_embs_name = None
+ if image_embs_name is not None and self.image_embs_name != image_embs_name:
+ self.image_embs_name = image_embs_name
+ self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
+ self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
+ self.image_embs.requires_grad_(False)
+
+ def __call__(self, z, remade_batch_tokens):
+ if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
+ tokenizer = shared.sd_model.cond_stage_model.tokenizer
+ if not opts.use_old_emphasis_implementation:
+ remade_batch_tokens = [
+ [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
+ remade_batch_tokens]
+
+ tokens = torch.asarray(remade_batch_tokens).to(device)
+
+ model = copy.deepcopy(shared.clip_model).to(device)
+ model.requires_grad_(True)
+ if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
+ text_embs_2 = model.get_text_features(
+ **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
+ if self.aesthetic_text_negative:
+ text_embs_2 = self.image_embs - text_embs_2
+ text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
+ img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
+ else:
+ img_embs = self.image_embs
+
+ with torch.enable_grad():
+
+ # We optimize the model to maximize the similarity
+ optimizer = optim.Adam(
+ model.text_model.parameters(), lr=self.aesthetic_lr
+ )
+
+ for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
+ text_embs = model.get_text_features(input_ids=tokens)
+ text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
+ sim = text_embs @ img_embs.T
+ loss = -sim
+ optimizer.zero_grad()
+ loss.mean().backward()
+ optimizer.step()
+
+ zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+ if opts.CLIP_stop_at_last_layers > 1:
+ zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
+ zn = model.text_model.final_layer_norm(zn)
+ else:
+ zn = zn.last_hidden_state
+ model.cpu()
+ del model
+ gc.collect()
+ torch.cuda.empty_cache()
+ zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
+ if self.slerp:
+ z = slerp(z, zn, self.aesthetic_weight)
+ else:
+ z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
+
+ return z
diff --git a/modules/img2img.py b/modules/img2img.py
index 24126774..4ed80c4b 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -56,7 +56,14 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str,
+ aesthetic_lr=0,
+ aesthetic_weight=0, aesthetic_steps=0,
+ aesthetic_imgs=None,
+ aesthetic_slerp=False,
+ aesthetic_imgs_text="",
+ aesthetic_slerp_angle=0.15,
+ aesthetic_text_negative=False, *args):
is_inpaint = mode == 1
is_batch = mode == 2
@@ -109,6 +116,11 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert,
)
+ shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
+ aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text,
+ aesthetic_slerp_angle,
+ aesthetic_text_negative)
+
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/processing.py b/modules/processing.py
index 1db26c3e..685f9fcd 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -146,7 +146,8 @@ class Processed:
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
- self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
+ self.subseed = int(
+ self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed]
@@ -332,16 +333,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
-def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0,
- aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="",
- aesthetic_slerp_angle=0.15,
- aesthetic_text_negative=False) -> Processed:
+def process_images(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
- aesthetic_lr = float(aesthetic_lr)
- aesthetic_weight = float(aesthetic_weight)
- aesthetic_steps = int(aesthetic_steps)
-
if type(p.prompt) == list:
assert (len(p.prompt) > 0)
else:
@@ -417,16 +411,10 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
# uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
# c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
- if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
- shared.sd_model.cond_stage_model.set_aesthetic_params()
+ shared.aesthetic_clip.set_skip(True)
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt],
p.steps)
- if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
- shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight,
- aesthetic_steps, aesthetic_imgs,
- aesthetic_slerp, aesthetic_imgs_text,
- aesthetic_slerp_angle,
- aesthetic_text_negative)
+ shared.aesthetic_clip.set_skip(False)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0:
@@ -582,7 +570,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
-
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
@@ -600,10 +587,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
- samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
+ samples = samples[:, :, self.truncate_y // 2:samples.shape[2] - self.truncate_y // 2,
+ self.truncate_x // 2:samples.shape[3] - self.truncate_x // 2]
if opts.use_scale_latent_for_hires_fix:
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f),
+ mode="bilinear")
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 5d0590af..227e7670 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -29,8 +29,8 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu
-
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
+ if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (
+ 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
@@ -118,33 +118,14 @@ class StableDiffusionModelHijack:
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
-def slerp(low, high, val):
- low_norm = low / torch.norm(low, dim=1, keepdim=True)
- high_norm = high / torch.norm(high, dim=1, keepdim=True)
- omega = torch.acos((low_norm * high_norm).sum(1))
- so = torch.sin(omega)
- res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
- return res
-
-
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
- self.clipModel = CLIPModel.from_pretrained(
- self.wrapped.transformer.name_or_path
- )
- del self.clipModel.vision_model
- self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path)
- self.hijack: StableDiffusionModelHijack = hijack
- self.tokenizer = wrapped.tokenizer
- # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval()
- self.image_embs_name = None
- self.image_embs = None
- self.load_image_embs(None)
self.token_mults = {}
-
+ self.hijack: StableDiffusionModelHijack = hijack
+ self.tokenizer = wrapped.tokenizer
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0]
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if
@@ -164,28 +145,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
- def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
- aesthetic_slerp=True, aesthetic_imgs_text="",
- aesthetic_slerp_angle=0.15,
- aesthetic_text_negative=False):
- self.aesthetic_imgs_text = aesthetic_imgs_text
- self.aesthetic_slerp_angle = aesthetic_slerp_angle
- self.aesthetic_text_negative = aesthetic_text_negative
- self.slerp = aesthetic_slerp
- self.aesthetic_lr = aesthetic_lr
- self.aesthetic_weight = aesthetic_weight
- self.aesthetic_steps = aesthetic_steps
- self.load_image_embs(image_embs_name)
-
- def load_image_embs(self, image_embs_name):
- if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
- image_embs_name = None
- if image_embs_name is not None and self.image_embs_name != image_embs_name:
- self.image_embs_name = image_embs_name
- self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
- self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
- self.image_embs.requires_grad_(False)
-
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_end = self.wrapped.tokenizer.eos_token_id
@@ -391,58 +350,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
- if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None:
- if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [
- [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
- remade_batch_tokens]
-
- tokens = torch.asarray(remade_batch_tokens).to(device)
-
- model = copy.deepcopy(self.clipModel).to(device)
- model.requires_grad_(True)
- if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
- text_embs_2 = model.get_text_features(
- **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
- if self.aesthetic_text_negative:
- text_embs_2 = self.image_embs - text_embs_2
- text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
- img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
- else:
- img_embs = self.image_embs
-
- with torch.enable_grad():
-
- # We optimize the model to maximize the similarity
- optimizer = optim.Adam(
- model.text_model.parameters(), lr=self.aesthetic_lr
- )
-
- for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
- text_embs = model.get_text_features(input_ids=tokens)
- text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
- sim = text_embs @ img_embs.T
- loss = -sim
- optimizer.zero_grad()
- loss.mean().backward()
- optimizer.step()
-
- zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
- if opts.CLIP_stop_at_last_layers > 1:
- zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
- zn = model.text_model.final_layer_norm(zn)
- else:
- zn = zn.last_hidden_state
- model.cpu()
- del model
-
- zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1)
- if self.slerp:
- z = slerp(z, zn, self.aesthetic_weight)
- else:
- z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
-
+ z = shared.aesthetic_clip(z, remade_batch_tokens)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3aa21ec1..8e4ee435 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -20,7 +20,7 @@ checkpoints_loaded = collections.OrderedDict()
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
- from transformers import logging
+ from transformers import logging, CLIPModel
logging.set_verbosity_error()
except Exception:
@@ -196,6 +196,9 @@ def load_model():
sd_hijack.model_hijack.hijack(sd_model)
+ if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path:
+ shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path)
+
sd_model.eval()
print(f"Model loaded.")
diff --git a/modules/shared.py b/modules/shared.py
index e2c98b2d..e19ca779 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -3,6 +3,7 @@ import datetime
import json
import os
import sys
+from collections import OrderedDict
import gradio as gr
import tqdm
@@ -94,15 +95,15 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
-aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
- os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
-aesthetic_embeddings = aesthetic_embeddings | {"None": None}
+aesthetic_embeddings = {}
def update_aesthetic_embeddings():
global aesthetic_embeddings
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
- aesthetic_embeddings = aesthetic_embeddings | {"None": None}
+ aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings)
+
+update_aesthetic_embeddings()
def reload_hypernetworks():
global hypernetworks
@@ -381,6 +382,11 @@ sd_upscalers = []
sd_model = None
+clip_model = None
+
+from modules.aesthetic_clip import AestheticCLIP
+aesthetic_clip = AestheticCLIP()
+
progress_print_out = sys.stdout
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 68ceffe3..23bb4b6a 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -49,7 +49,7 @@ class PersonalizedBase(Dataset):
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
try:
- image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC)
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception:
continue
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 8f394d05..6cbc50fc 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,12 +1,17 @@
import modules.scripts
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
+ StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int,aesthetic_lr=0,
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int,
+ restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int,
+ subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool,
+ height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int,
+ firstphase_height: int, aesthetic_lr=0,
aesthetic_weight=0, aesthetic_steps=0,
aesthetic_imgs=None,
aesthetic_slerp=False,
@@ -41,15 +46,17 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
firstphase_height=firstphase_height if enable_hr else None,
)
+ shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
+ aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle,
+ aesthetic_text_negative)
+
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
- processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,aesthetic_imgs_text,
- aesthetic_slerp_angle,
- aesthetic_text_negative)
+ processed = process_images(p)
shared.total_tqdm.clear()
@@ -61,4 +68,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info)
-
diff --git a/modules/ui.py b/modules/ui.py
index 4069f0d2..0e5d73f0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -43,7 +43,7 @@ from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui
-import modules.aesthetic_clip
+import modules.aesthetic_clip as aesthetic_clip
import modules.images_history as img_his
@@ -593,23 +593,25 @@ def create_ui(wrap_gradio_gpu_call):
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
- with gr.Group():
- with gr.Accordion("Open for Clip Aesthetic!",open=False):
- with gr.Row():
- aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
- aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
-
- with gr.Row():
- aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001")
- aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
- aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()),
- label="Aesthetic imgs embedding",
- value="None")
-
- with gr.Row():
- aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
- aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
- aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
+ # with gr.Group():
+ # with gr.Accordion("Open for Clip Aesthetic!",open=False):
+ # with gr.Row():
+ # aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
+ # aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
+ #
+ # with gr.Row():
+ # aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001")
+ # aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
+ # aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()),
+ # label="Aesthetic imgs embedding",
+ # value="None")
+ #
+ # with gr.Row():
+ # aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
+ # aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
+ # aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
+
+ aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui()
with gr.Row():
@@ -840,6 +842,9 @@ def create_ui(wrap_gradio_gpu_call):
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui()
+
+
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
tiling = gr.Checkbox(label='Tiling', value=False)
@@ -944,6 +949,14 @@ def create_ui(wrap_gradio_gpu_call):
inpainting_mask_invert,
img2img_batch_input_dir,
img2img_batch_output_dir,
+ aesthetic_lr_im,
+ aesthetic_weight_im,
+ aesthetic_steps_im,
+ aesthetic_imgs_im,
+ aesthetic_slerp_im,
+ aesthetic_imgs_text_im,
+ aesthetic_slerp_angle_im,
+ aesthetic_text_negative_im,
] + custom_inputs,
outputs=[
img2img_gallery,
@@ -1283,7 +1296,7 @@ def create_ui(wrap_gradio_gpu_call):
)
create_embedding_ae.click(
- fn=modules.aesthetic_clip.generate_imgs_embd,
+ fn=aesthetic_clip.generate_imgs_embd,
inputs=[
new_embedding_name_ae,
process_src_ae,
@@ -1291,6 +1304,7 @@ def create_ui(wrap_gradio_gpu_call):
],
outputs=[
aesthetic_imgs,
+ aesthetic_imgs_im,
ti_output,
ti_outcome,
]
--
cgit v1.2.3
From 9d702b16f01795c3af900e0ebd70faf4b25200f6 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Mon, 17 Oct 2022 16:11:03 +0800
Subject: fix two little bugs
---
modules/images_history.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index 23045df1..1ae168ca 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -133,7 +133,7 @@ def archive_images(dir_name, date_to):
date = sort_array[loads_num][2]
filenames = [x[1] for x in sort_array]
else:
- date = sort_array[loads_num][2]
+ date = sort_array[-1][2]
filenames = [x[1] for x in sort_array]
filenames = [x[1] for x in sort_array if x[2]>= date]
_, image_list, _, visible_num = get_recent_images(1, 0, filenames)
@@ -334,6 +334,6 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict):
with gr.Tab(tab):
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
- gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory") #, visible=False)
+ gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False)
return images_history
--
cgit v1.2.3
From c408a0b41cfffde184cad35b2d97346342947d83 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Mon, 17 Oct 2022 22:28:43 +0800
Subject: fix two bugs
---
modules/images_history.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index 1ae168ca..10e5b970 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -181,7 +181,8 @@ def delete_image(delete_num, name, filenames, image_index, visible_num):
return new_file_list, 1, visible_num
def save_image(file_name):
- shutil.copy2(file_name, opts.outdir_save)
+ if file_name is not None and os.path.exists(file_name):
+ shutil.copy2(file_name, opts.outdir_save)
def get_recent_images(page_index, step, filenames):
page_index = int(page_index)
@@ -327,7 +328,6 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict):
opts = sys_opts
loads_files_num = int(opts.images_history_num_per_page)
num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num)
- backup_flag = opts.images_history_backup
with gr.Blocks(analytics_enabled=False) as images_history:
with gr.Tabs() as tabs:
for tab in ["txt2img", "img2img", "extras", "saved"]:
--
cgit v1.2.3
From 2272cf2f35fafd5cd486bfb4ee89df5bbc625b97 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Mon, 17 Oct 2022 23:04:42 +0800
Subject: fix two bugs
---
modules/images_history.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index 10e5b970..1c1790a4 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -133,7 +133,7 @@ def archive_images(dir_name, date_to):
date = sort_array[loads_num][2]
filenames = [x[1] for x in sort_array]
else:
- date = sort_array[-1][2]
+ date = None if len(sort_array) == 0 else sort_array[-1][2]
filenames = [x[1] for x in sort_array]
filenames = [x[1] for x in sort_array if x[2]>= date]
_, image_list, _, visible_num = get_recent_images(1, 0, filenames)
--
cgit v1.2.3
From 2b5b62e768d892773a7ec1d5e8d8cea23aae1254 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Mon, 17 Oct 2022 23:14:03 +0800
Subject: fix two bugs
---
modules/images_history.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index 1c1790a4..20324557 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -44,7 +44,7 @@ def traverse_all_files(curr_path, image_list, all_type=False):
return image_list
for file in f_list:
file = os.path.join(curr_path, file)
- if (not all_type) and file[-4:] == ".txt":
+ if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"):
pass
elif os.path.isfile(file) and file[-10:].rfind(".") > 0:
image_list.append(file)
@@ -182,7 +182,7 @@ def delete_image(delete_num, name, filenames, image_index, visible_num):
def save_image(file_name):
if file_name is not None and os.path.exists(file_name):
- shutil.copy2(file_name, opts.outdir_save)
+ shutil.copy(file_name, opts.outdir_save)
def get_recent_images(page_index, step, filenames):
page_index = int(page_index)
--
cgit v1.2.3
From eb299527b1e5d1f83a14641647fca72e8fb305ac Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Tue, 18 Oct 2022 20:14:11 +0800
Subject: Image browser
---
modules/images_history.py | 227 ++++++++++++++++++++++++++++++----------------
modules/shared.py | 7 +-
modules/ui.py | 2 +-
3 files changed, 154 insertions(+), 82 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index 20324557..d56f3a25 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -4,6 +4,7 @@ import time
import hashlib
import gradio
system_bak_path = "webui_log_and_bak"
+browser_tabname = "custom"
def is_valid_date(date):
try:
time.strptime(date, "%Y%m%d")
@@ -99,13 +100,15 @@ def auto_sorting(dir_name):
date_list.append(today)
return sorted(date_list, reverse=True)
-def archive_images(dir_name, date_to):
+def archive_images(dir_name, date_to):
+
filenames = []
loads_num =int(opts.images_history_num_per_page * opts.images_history_pages_num)
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ date_to = today if date_to is None or date_to == "" else date_to
+ date_to_bak = date_to
if opts.images_history_reconstruct_directory:
- date_list = auto_sorting(dir_name)
- today = time.strftime("%Y%m%d",time.localtime(time.time()))
- date_to = today if date_to is None or date_to == "" else date_to
+ date_list = auto_sorting(dir_name)
for date in date_list:
if date <= date_to:
path = os.path.join(dir_name, date)
@@ -120,7 +123,7 @@ def archive_images(dir_name, date_to):
tmparray = [(os.path.getmtime(file), file) for file in filenames ]
date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400
filenames = []
- date_list = {}
+ date_list = {date_to:None}
date = time.strftime("%Y%m%d",time.localtime(time.time()))
for t, f in tmparray:
date = time.strftime("%Y%m%d",time.localtime(t))
@@ -133,22 +136,29 @@ def archive_images(dir_name, date_to):
date = sort_array[loads_num][2]
filenames = [x[1] for x in sort_array]
else:
- date = None if len(sort_array) == 0 else sort_array[-1][2]
+ date = date_to if len(sort_array) == 0 else sort_array[-1][2]
filenames = [x[1] for x in sort_array]
- filenames = [x[1] for x in sort_array if x[2]>= date]
- _, image_list, _, visible_num = get_recent_images(1, 0, filenames)
+ filenames = [x[1] for x in sort_array if x[2]>= date]
+ num = len(filenames)
+ last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000))
+ date = date[:4] + "-" + date[4:6] + "-" + date[6:8]
+ date_to_bak = date_to_bak[:4] + "-" + date_to_bak[4:6] + "-" + date_to_bak[6:8]
+ load_info = f"Loaded {(num + 1) // opts.images_history_pages_num} pades, {num} images, during {date} - {date_to_bak}"
+ _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames)
return (
gradio.Dropdown.update(choices=date_list, value=date_to),
- date,
+ load_info,
filenames,
1,
image_list,
"",
- visible_num
+ "",
+ visible_num,
+ last_date_from
)
-def newest_click(dir_name, date_to):
- return archive_images(dir_name, time.strftime("%Y%m%d",time.localtime(time.time())))
+
+
def delete_image(delete_num, name, filenames, image_index, visible_num):
if name == "":
@@ -196,7 +206,29 @@ def get_recent_images(page_index, step, filenames):
length = len(filenames)
visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page
visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num
- return page_index, image_list, "", visible_num
+ return page_index, image_list, "", "", visible_num
+
+def newest_click(date_to):
+ if date_to is None:
+ return time.strftime("%Y%m%d",time.localtime(time.time())), []
+ else:
+ return None, []
+def forward_click(last_date_from, date_to_recorder):
+ if len(date_to_recorder) == 0:
+ return None, []
+ if last_date_from == date_to_recorder[-1]:
+ date_to_recorder = date_to_recorder[:-1]
+ if len(date_to_recorder) == 0:
+ return None, []
+ return date_to_recorder[-1], date_to_recorder[:-1]
+
+def backward_click(last_date_from, date_to_recorder):
+ if last_date_from is None or last_date_from == "":
+ return time.strftime("%Y%m%d",time.localtime(time.time())), []
+ if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]:
+ date_to_recorder.append(last_date_from)
+ return last_date_from, date_to_recorder
+
def first_page_click(page_index, filenames):
return get_recent_images(1, 0, filenames)
@@ -214,13 +246,33 @@ def page_index_change(page_index, filenames):
return get_recent_images(page_index, 0, filenames)
def show_image_info(tabname_box, num, page_index, filenames):
- file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))]
- return file, num, file
+ file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))]
+ tm = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file)))
+ return file, tm, num, file
def enable_page_buttons():
return gradio.update(visible=True)
+def change_dir(img_dir, date_to):
+ warning = None
+ try:
+ if os.path.exists(img_dir):
+ try:
+ f = os.listdir(img_dir)
+ except:
+ warning = f"'{img_dir} is not a directory"
+ else:
+ warning = "The directory is not exist"
+ except:
+ warning = "The format of the directory is incorrect"
+ if warning is None:
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today
+ else:
+ return gradio.update(visible=True), gradio.update(visible=False), warning, date_to
+
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
+ custom_dir = False
if tabname == "txt2img":
dir_name = opts.outdir_txt2img_samples
elif tabname == "img2img":
@@ -229,69 +281,85 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
dir_name = opts.outdir_extras_samples
elif tabname == "saved":
dir_name = opts.outdir_save
+ else:
+ custom_dir = True
+ dir_name = None
+
+ if not custom_dir:
+ d = dir_name.split("/")
+ dir_name = d[0]
+ for p in d[1:]:
+ dir_name = os.path.join(dir_name, p)
+ if not os.path.exists(dir_name):
+ os.makedirs(dir_name)
- d = dir_name.split("/")
- dir_name = d[0]
- for p in d[1:]:
- dir_name = os.path.join(dir_name, p)
- if not os.path.exists(dir_name):
- os.makedirs(dir_name)
-
- with gr.Column() as page_panel:
- with gr.Row(visible=False) as turn_page_buttons:
- renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
- first_page = gr.Button('First Page')
- prev_page = gr.Button('Prev Page')
- page_index = gr.Number(value=1, label="Page Index")
- next_page = gr.Button('Next Page')
- end_page = gr.Button('End Page')
-
- with gr.Row(elem_id=tabname + "_images_history"):
- with gr.Column(scale=2):
- with gr.Row():
- newest = gr.Button('Reload', elem_id=tabname + "_images_history_start")
- date_from = gr.Textbox(label="Date from", interactive=False)
- date_to = gr.Dropdown(label="Date to")
-
- history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
- with gr.Row():
- delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
- delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
-
- with gr.Column():
- with gr.Row():
- if tabname != "saved":
- save_btn = gr.Button('Save')
- pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
- pnginfo_send_to_img2img = gr.Button('Send to img2img')
- with gr.Row():
- with gr.Column():
- img_file_info = gr.Textbox(label="Generate Info", interactive=False)
- img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
+ with gr.Column() as page_panel:
+ with gr.Row():
+ img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory")
+ with gr.Row(visible=False) as warning:
+ warning_box = gr.Textbox("Message", interactive=False)
+ with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel:
+ with gr.Column(scale=2):
+ with gr.Row():
+ backward = gr.Button('Backward')
+ date_to = gr.Dropdown(label="Date to")
+ forward = gr.Button('Forward')
+ newest = gr.Button('Reload', elem_id=tabname + "_images_history_start")
+ with gr.Row():
+ load_info = gr.Textbox(show_label=False, interactive=False)
+ with gr.Row(visible=False) as turn_page_buttons:
+ renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
+ first_page = gr.Button('First Page')
+ prev_page = gr.Button('Prev Page')
+ page_index = gr.Number(value=1, label="Page Index")
+ next_page = gr.Button('Next Page')
+ end_page = gr.Button('End Page')
+
+ history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num)
+ with gr.Row():
+ delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
+ delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
- # hidden items
- with gr.Row(visible=False):
- visible_img_num = gr.Number()
- img_path = gr.Textbox(dir_name)
- tabname_box = gr.Textbox(tabname)
- image_index = gr.Textbox(value=-1)
- set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
- filenames = gr.State()
- all_images_list = gr.State()
- hidden = gr.Image(type="pil")
- info1 = gr.Textbox()
- info2 = gr.Textbox()
+ with gr.Column():
+ with gr.Row():
+ if tabname != "saved":
+ save_btn = gr.Button('Save')
+ pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
+ pnginfo_send_to_img2img = gr.Button('Send to img2img')
+ with gr.Row():
+ with gr.Column():
+ img_file_info = gr.Textbox(label="Generate Info", interactive=False)
+ img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
+ img_file_time= gr.Textbox(value="", label="Create Time", interactive=False)
-
+
+ # hidden items
+ with gr.Row(): #visible=False):
+ visible_img_num = gr.Number()
+ date_to_recorder = gr.State([])
+ last_date_from = gr.Textbox()
+ tabname_box = gr.Textbox(tabname)
+ image_index = gr.Textbox(value=-1)
+ set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
+ filenames = gr.State()
+ all_images_list = gr.State()
+ hidden = gr.Image(type="pil")
+ info1 = gr.Textbox()
+ info2 = gr.Textbox()
+
+ img_path.submit(change_dir, inputs=[img_path, date_to], outputs=[warning, main_panel, warning_box, date_to])
#change date
- change_date_output = [date_to, date_from, filenames, page_index, history_gallery, img_file_name, visible_img_num]
- newest.click(newest_click, inputs=[img_path, date_to], outputs=change_date_output)
- date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output)
- newest.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
- date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
- date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
- newest.click(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
+ change_date_output = [date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from]
+
+ date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output)
+ date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
+ date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+
+ newest.click(newest_click, inputs=[date_to], outputs=[date_to, date_to_recorder])
+ forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder])
+ backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder])
+
#delete
delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num])
@@ -301,7 +369,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
#turn page
gallery_inputs = [page_index, filenames]
- gallery_outputs = [page_index, history_gallery, img_file_name, visible_img_num]
+ gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num]
first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
@@ -317,12 +385,14 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
# other functions
- set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden])
+ set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden])
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
+
+
def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict):
global opts;
opts = sys_opts
@@ -330,10 +400,11 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict):
num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num)
with gr.Blocks(analytics_enabled=False) as images_history:
with gr.Tabs() as tabs:
- for tab in ["txt2img", "img2img", "extras", "saved"]:
+ for tab in [browser_tabname, "txt2img", "img2img", "extras", "saved"]:
with gr.Tab(tab):
- with gr.Blocks(analytics_enabled=False) as images_history_img2img:
+ with gr.Blocks(analytics_enabled=False) :
show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
- gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False)
+ #gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False)
+ gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_finish_render", visible=False)
return images_history
diff --git a/modules/shared.py b/modules/shared.py
index c2ea4186..1811018d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -309,10 +309,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
-options_templates.update(options_section(('images-history', "Images history"), {
- "images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"),
+options_templates.update(options_section(('images-history', "Images Browser"), {
+ #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"),
"images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"),
- "images_history_pages_num": OptionInfo(6, "Maximum number of pages per load "),
+ "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "),
+ "images_history_grid_num": OptionInfo(6, "Number of grids in each row"),
}))
diff --git a/modules/ui.py b/modules/ui.py
index 43dc88fc..85abac4d 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1548,7 +1548,7 @@ Requested path was: {f}
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
- (images_history, "History", "images_history"),
+ (images_history, "Image Browser", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
(settings_interface, "Settings", "settings"),
--
cgit v1.2.3
From b7e78ef692fe912916de6e54f6e2521b000d650c Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Tue, 18 Oct 2022 22:21:54 +0800
Subject: Image browser improvements
---
modules/images_history.py | 43 ++++++++++++++++++++++---------------------
1 file changed, 22 insertions(+), 21 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index d56f3a25..a40cdc0e 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -100,14 +100,15 @@ def auto_sorting(dir_name):
date_list.append(today)
return sorted(date_list, reverse=True)
-def archive_images(dir_name, date_to):
-
+def archive_images(dir_name, date_to):
filenames = []
- loads_num =int(opts.images_history_num_per_page * opts.images_history_pages_num)
+ batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num)
+ if batch_size <= 0:
+ batch_size = opts.images_history_num_per_page * 6
today = time.strftime("%Y%m%d",time.localtime(time.time()))
date_to = today if date_to is None or date_to == "" else date_to
date_to_bak = date_to
- if opts.images_history_reconstruct_directory:
+ if False: #opts.images_history_reconstruct_directory:
date_list = auto_sorting(dir_name)
for date in date_list:
if date <= date_to:
@@ -115,11 +116,13 @@ def archive_images(dir_name, date_to):
if date == today and not os.path.exists(path):
continue
filenames = traverse_all_files(path, filenames)
- if len(filenames) > loads_num:
+ if len(filenames) > batch_size:
break
filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file))
else:
- filenames = traverse_all_files(dir_name, filenames)
+ filenames = traverse_all_files(dir_name, filenames)
+ total_num = len(filenames)
+ batch_count = len(filenames) + 1 // batch_size + 1
tmparray = [(os.path.getmtime(file), file) for file in filenames ]
date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400
filenames = []
@@ -132,8 +135,8 @@ def archive_images(dir_name, date_to):
filenames.append((t, f ,date))
date_list = sorted(list(date_list.keys()), reverse=True)
sort_array = sorted(filenames, key=lambda x:-x[0])
- if len(sort_array) > loads_num:
- date = sort_array[loads_num][2]
+ if len(sort_array) > batch_size:
+ date = sort_array[batch_size][2]
filenames = [x[1] for x in sort_array]
else:
date = date_to if len(sort_array) == 0 else sort_array[-1][2]
@@ -141,9 +144,9 @@ def archive_images(dir_name, date_to):
filenames = [x[1] for x in sort_array if x[2]>= date]
num = len(filenames)
last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000))
- date = date[:4] + "-" + date[4:6] + "-" + date[6:8]
- date_to_bak = date_to_bak[:4] + "-" + date_to_bak[4:6] + "-" + date_to_bak[6:8]
- load_info = f"Loaded {(num + 1) // opts.images_history_pages_num} pades, {num} images, during {date} - {date_to_bak}"
+ date = date[:4] + "/" + date[4:6] + "/" + date[6:8]
+ date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8]
+ load_info = f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages"
_, image_list, _, _, visible_num = get_recent_images(1, 0, filenames)
return (
gradio.Dropdown.update(choices=date_list, value=date_to),
@@ -154,12 +157,10 @@ def archive_images(dir_name, date_to):
"",
"",
visible_num,
- last_date_from
+ last_date_from,
+ #gradio.update(visible=batch_count > 1)
)
-
-
-
def delete_image(delete_num, name, filenames, image_index, visible_num):
if name == "":
return filenames, delete_num
@@ -295,16 +296,16 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
with gr.Column() as page_panel:
with gr.Row():
- img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory")
+ img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir)
with gr.Row(visible=False) as warning:
warning_box = gr.Textbox("Message", interactive=False)
with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel:
with gr.Column(scale=2):
- with gr.Row():
- backward = gr.Button('Backward')
- date_to = gr.Dropdown(label="Date to")
- forward = gr.Button('Forward')
+ with gr.Row() as batch_panel:
+ forward = gr.Button('Forward')
+ date_to = gr.Dropdown(label="Date to")
+ backward = gr.Button('Backward')
newest = gr.Button('Reload', elem_id=tabname + "_images_history_start")
with gr.Row():
load_info = gr.Textbox(show_label=False, interactive=False)
@@ -335,7 +336,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
# hidden items
- with gr.Row(): #visible=False):
+ with gr.Row(visible=False):
visible_img_num = gr.Number()
date_to_recorder = gr.State([])
last_date_from = gr.Textbox()
--
cgit v1.2.3
From 538bc89c269743e56b07ef2b471d1ce0a39b6776 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Wed, 19 Oct 2022 11:27:51 +0800
Subject: Image browser improved
---
modules/images_history.py | 135 +++++++++++++++++++++++++---------------------
modules/shared.py | 5 ++
modules/ui.py | 2 +-
3 files changed, 80 insertions(+), 62 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index a40cdc0e..78fd0543 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -4,7 +4,9 @@ import time
import hashlib
import gradio
system_bak_path = "webui_log_and_bak"
-browser_tabname = "custom"
+custom_tab_name = "custom fold"
+faverate_tab_name = "favorites"
+tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name]
def is_valid_date(date):
try:
time.strptime(date, "%Y%m%d")
@@ -122,7 +124,6 @@ def archive_images(dir_name, date_to):
else:
filenames = traverse_all_files(dir_name, filenames)
total_num = len(filenames)
- batch_count = len(filenames) + 1 // batch_size + 1
tmparray = [(os.path.getmtime(file), file) for file in filenames ]
date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400
filenames = []
@@ -146,10 +147,12 @@ def archive_images(dir_name, date_to):
last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000))
date = date[:4] + "/" + date[4:6] + "/" + date[6:8]
date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8]
- load_info = f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages"
+ load_info = ""
+ load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages"
+ load_info += "
"
_, image_list, _, _, visible_num = get_recent_images(1, 0, filenames)
return (
- gradio.Dropdown.update(choices=date_list, value=date_to),
+ date_to,
load_info,
filenames,
1,
@@ -158,7 +161,7 @@ def archive_images(dir_name, date_to):
"",
visible_num,
last_date_from,
- #gradio.update(visible=batch_count > 1)
+ gradio.update(visible=total_num > num)
)
def delete_image(delete_num, name, filenames, image_index, visible_num):
@@ -209,7 +212,7 @@ def get_recent_images(page_index, step, filenames):
visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num
return page_index, image_list, "", "", visible_num
-def newest_click(date_to):
+def load_batch_click(date_to):
if date_to is None:
return time.strftime("%Y%m%d",time.localtime(time.time())), []
else:
@@ -248,7 +251,7 @@ def page_index_change(page_index, filenames):
def show_image_info(tabname_box, num, page_index, filenames):
file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))]
- tm = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file)))
+ tm = "" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "
"
return file, tm, num, file
def enable_page_buttons():
@@ -268,9 +271,9 @@ def change_dir(img_dir, date_to):
warning = "The format of the directory is incorrect"
if warning is None:
today = time.strftime("%Y%m%d",time.localtime(time.time()))
- return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today
+ return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True)
else:
- return gradio.update(visible=True), gradio.update(visible=False), warning, date_to
+ return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False)
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
custom_dir = False
@@ -280,7 +283,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
dir_name = opts.outdir_img2img_samples
elif tabname == "extras":
dir_name = opts.outdir_extras_samples
- elif tabname == "saved":
+ elif tabname == faverate_tab_name:
dir_name = opts.outdir_save
else:
custom_dir = True
@@ -295,22 +298,26 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
os.makedirs(dir_name)
with gr.Column() as page_panel:
- with gr.Row():
- img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir)
+ with gr.Row():
+ with gr.Column(scale=1, visible=not custom_dir) as load_batch_box:
+ load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True)
+ with gr.Column(scale=4):
+ with gr.Row():
+ img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir)
+ with gr.Row():
+ with gr.Column(visible=False, scale=1) as batch_panel:
+ with gr.Row():
+ forward = gr.Button('Prev batch')
+ backward = gr.Button('Next batch')
+ with gr.Column(scale=3):
+ load_info = gr.HTML(visible=not custom_dir)
with gr.Row(visible=False) as warning:
warning_box = gr.Textbox("Message", interactive=False)
with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel:
- with gr.Column(scale=2):
- with gr.Row() as batch_panel:
- forward = gr.Button('Forward')
- date_to = gr.Dropdown(label="Date to")
- backward = gr.Button('Backward')
- newest = gr.Button('Reload', elem_id=tabname + "_images_history_start")
- with gr.Row():
- load_info = gr.Textbox(show_label=False, interactive=False)
- with gr.Row(visible=False) as turn_page_buttons:
- renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
+ with gr.Column(scale=2):
+ with gr.Row(visible=True) as turn_page_buttons:
+ #date_to = gr.Dropdown(label="Date to")
first_page = gr.Button('First Page')
prev_page = gr.Button('Prev Page')
page_index = gr.Number(value=1, label="Page Index")
@@ -322,50 +329,54 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
- with gr.Column():
- with gr.Row():
- if tabname != "saved":
- save_btn = gr.Button('Save')
- pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
- pnginfo_send_to_img2img = gr.Button('Send to img2img')
+ with gr.Column():
with gr.Row():
with gr.Column():
- img_file_info = gr.Textbox(label="Generate Info", interactive=False)
+ img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6)
+ gr.HTML(" ")
img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
- img_file_time= gr.Textbox(value="", label="Create Time", interactive=False)
-
+ img_file_time= gr.HTML()
+ with gr.Row():
+ if tabname != faverate_tab_name:
+ save_btn = gr.Button('Collect')
+ pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
+ pnginfo_send_to_img2img = gr.Button('Send to img2img')
+
- # hidden items
- with gr.Row(visible=False):
- visible_img_num = gr.Number()
- date_to_recorder = gr.State([])
- last_date_from = gr.Textbox()
- tabname_box = gr.Textbox(tabname)
- image_index = gr.Textbox(value=-1)
- set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
- filenames = gr.State()
- all_images_list = gr.State()
- hidden = gr.Image(type="pil")
- info1 = gr.Textbox()
- info2 = gr.Textbox()
-
- img_path.submit(change_dir, inputs=[img_path, date_to], outputs=[warning, main_panel, warning_box, date_to])
- #change date
- change_date_output = [date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from]
+ # hidden items
+ with gr.Row(visible=False):
+ renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
+ batch_date_to = gr.Textbox(label="Date to")
+ visible_img_num = gr.Number()
+ date_to_recorder = gr.State([])
+ last_date_from = gr.Textbox()
+ tabname_box = gr.Textbox(tabname)
+ image_index = gr.Textbox(value=-1)
+ set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
+ filenames = gr.State()
+ all_images_list = gr.State()
+ hidden = gr.Image(type="pil")
+ info1 = gr.Textbox()
+ info2 = gr.Textbox()
+
+ img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info])
+
+ #change batch
+ change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel]
- date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output)
- date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
- date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output)
+ batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
+ batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
- newest.click(newest_click, inputs=[date_to], outputs=[date_to, date_to_recorder])
- forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder])
- backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder])
+ load_batch.click(load_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder])
+ forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
+ backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
#delete
delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num])
delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None)
- if tabname != "saved":
+ if tabname != faverate_tab_name:
save_btn.click(save_image, inputs=[img_file_name], outputs=None)
#turn page
@@ -394,18 +405,20 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
-def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict):
+def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict):
global opts;
opts = sys_opts
loads_files_num = int(opts.images_history_num_per_page)
num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num)
+ if cmp_ops.browse_all_images:
+ tabs_list.append(custom_tab_name)
with gr.Blocks(analytics_enabled=False) as images_history:
with gr.Tabs() as tabs:
- for tab in [browser_tabname, "txt2img", "img2img", "extras", "saved"]:
+ for tab in tabs_list:
with gr.Tab(tab):
with gr.Blocks(analytics_enabled=False) :
- show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
- #gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False)
- gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_finish_render", visible=False)
-
+ show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
+ gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False)
+ gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False)
+
return images_history
diff --git a/modules/shared.py b/modules/shared.py
index 1811018d..4d735414 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -74,6 +74,10 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
+parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False)
+
+
cmd_opts = parser.parse_args()
@@ -311,6 +315,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
options_templates.update(options_section(('images-history', "Images Browser"), {
#"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"),
+ "images_history_preload": OptionInfo(False, "Preload images at startup"),
"images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"),
"images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "),
"images_history_grid_num": OptionInfo(6, "Number of grids in each row"),
diff --git a/modules/ui.py b/modules/ui.py
index 85abac4d..88f46659 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1150,7 +1150,7 @@ def create_ui(wrap_gradio_gpu_call):
"i2i":img2img_paste_fields
}
- images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
+ images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
--
cgit v1.2.3
From abeec4b63029c2c4151a78fc395d312113881845 Mon Sep 17 00:00:00 2001
From: captin411
Date: Wed, 19 Oct 2022 03:18:26 -0700
Subject: Add auto focal point cropping to Preprocess images
This algorithm plots a bunch of points of interest on the source
image and averages their locations to find a center.
Most points come from OpenCV. One point comes from an
entropy model. OpenCV points account for 50% of the weight and the
entropy-based point is the other 50%.
The center of all weighted points is calculated and a bounding box
is drawn as close to centered over that point as possible.
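For orientation, here is a minimal, self-contained sketch of the weighted-centroid-and-clamped-crop idea this message describes (toy point data; the patch below derives its points from cv2.goodFeaturesToTrack plus an entropy scan):

# Sketch (not part of the patch): average weighted points of interest
# to pick a crop center, then clamp the crop box into the image bounds.
def weighted_center(points):
    total = sum(p['weight'] for p in points)
    cx = sum(p['x'] * p['weight'] for p in points) / total
    cy = sum(p['y'] * p['weight'] for p in points) / total
    return int(cx), int(cy)

def centered_crop_box(cx, cy, crop_w, crop_h, img_w, img_h):
    # center the box over (cx, cy), then push it back into the frame
    x1 = min(max(cx - crop_w // 2, 0), img_w - crop_w)
    y1 = min(max(cy - crop_h // 2, 0), img_h - crop_h)
    return (x1, y1, x1 + crop_w, y1 + crop_h)

# toy data: three detector points plus one entropy point whose weight
# equals the rest combined, mirroring the 50/50 split described above
pts = [{'x': 100, 'y': 40, 'weight': 1.0},
       {'x': 120, 'y': 60, 'weight': 1.0},
       {'x': 90, 'y': 55, 'weight': 1.0},
       {'x': 300, 'y': 200, 'weight': 3.0}]
print(centered_crop_box(*weighted_center(pts), 256, 256, 640, 480))

Running this prints (73, 0, 329, 256): the box is centered on the weighted average where possible and clamped at the top edge.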
---
modules/textual_inversion/preprocess.py | 151 ++++++++++++++++++++++++++++++--
1 file changed, 146 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 886cf0c3..168bfb09 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,5 +1,7 @@
import os
-from PIL import Image, ImageOps
+import cv2
+import numpy as np
+from PIL import Image, ImageOps, ImageDraw
import platform
import sys
import tqdm
@@ -11,7 +13,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False):
try:
if process_caption:
shared.interrogator.load()
@@ -21,7 +23,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
- preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+ preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, process_entropy_focus)
finally:
@@ -33,7 +35,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -93,6 +95,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
is_tall = ratio > 1.35
is_wide = ratio < 1 / 1.35
+ processing_option_ran = False
+
if process_split and is_tall:
img = img.resize((width, height * img.height // img.width))
@@ -101,6 +105,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
bot = img.crop((0, img.height - height, width, img.height))
save_pic(bot, index)
+
+ processing_option_ran = True
elif process_split and is_wide:
img = img.resize((width * img.width // img.height, height))
@@ -109,8 +115,143 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
right = img.crop((img.width - width, 0, img.width, height))
save_pic(right, index)
- else:
+
+ processing_option_ran = True
+
+ if process_entropy_focus and (is_tall or is_wide):
+ if is_tall:
+ img = img.resize((width, height * img.height // img.width))
+ else:
+ img = img.resize((width * img.width // img.height, height))
+
+ x_focal_center, y_focal_center = image_central_focal_point(img, width, height)
+
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(height / 2)
+ x_half = int(width / 2)
+
+ x1 = x_focal_center - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + width > img.width:
+ x1 = img.width - width
+
+ y1 = y_focal_center - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + height > img.height:
+ y1 = img.height - height
+
+ x2 = x1 + width
+ y2 = y1 + height
+
+ crop = [x1, y1, x2, y2]
+
+ focal = img.crop(tuple(crop))
+ save_pic(focal, index)
+
+ processing_option_ran = True
+
+ if not processing_option_ran:
img = images.resize_image(1, img, width, height)
save_pic(img, index)
shared.state.nextjob()
+
+
+def image_central_focal_point(im, target_width, target_height):
+ focal_points = []
+
+ focal_points.extend(
+ image_focal_points(im)
+ )
+
+ fp_entropy = image_entropy_point(im, target_width, target_height)
+ fp_entropy['weight'] = len(focal_points) + 1 # about half of the weight to entropy
+
+ focal_points.append(fp_entropy)
+
+ weight = 0.0
+ x = 0.0
+ y = 0.0
+ for focal_point in focal_points:
+ weight += focal_point['weight']
+ x += focal_point['x'] * focal_point['weight']
+ y += focal_point['y'] * focal_point['weight']
+ avg_x = round(x // weight)
+ avg_y = round(y // weight)
+
+ return avg_x, avg_y
+
+
+def image_focal_points(im):
+ grayscale = im.convert("L")
+
+ # naive attempt at preventing focal points from collecting at watermarks near the bottom
+ gd = ImageDraw.Draw(grayscale)
+ gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
+
+ np_im = np.array(grayscale)
+
+ points = cv2.goodFeaturesToTrack(
+ np_im,
+ maxCorners=50,
+ qualityLevel=0.04,
+ minDistance=min(grayscale.width, grayscale.height)*0.05,
+ useHarrisDetector=False,
+ )
+
+ if points is None:
+ return []
+
+ focal_points = []
+ for point in points:
+ x, y = point.ravel()
+ focal_points.append({
+ 'x': x,
+ 'y': y,
+ 'weight': 1.0
+ })
+
+ return focal_points
+
+
+def image_entropy_point(im, crop_width, crop_height):
+ img = im.copy()
+ # just make it easier to slide the test crop with images oriented the same way
+ if (img.size[0] < img.size[1]):
+ portrait = True
+ img = img.rotate(90, expand=1)
+
+ e_max = 0
+ crop_current = [0, 0, crop_width, crop_height]
+ crop_best = crop_current
+ while crop_current[2] < img.size[0]:
+ crop = img.crop(tuple(crop_current))
+ e = image_entropy(crop)
+
+ if (e_max < e):
+ e_max = e
+ crop_best = list(crop_current)
+
+ crop_current[0] += 4
+ crop_current[2] += 4
+
+ x_mid = int((crop_best[2] - crop_best[0])/2)
+ y_mid = int((crop_best[3] - crop_best[1])/2)
+
+ return {
+ 'x': x_mid,
+ 'y': y_mid,
+ 'weight': 1.0
+ }
+
+
+def image_entropy(im):
+ # greyscale image entropy
+ band = np.asarray(im.convert("L"))
+ hist, _ = np.histogram(band, bins=range(0, 256))
+ hist = hist[hist > 0]
+ return -np.log2(hist / hist.sum()).sum()
+
--
cgit v1.2.3
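A note on the entropy term used above: image_entropy sums -log2(p) over the occupied histogram bins, i.e. per-bin surprisal without the probability weighting of textbook Shannon entropy. A standalone sketch of the probability-weighted variant, for comparison only (synthetic data, not part of the patch):

import numpy as np

def shannon_entropy(band):
    # p * -log2(p), summed over occupied greyscale histogram bins
    hist, _ = np.histogram(band, bins=range(0, 257))
    p = hist[hist > 0] / hist.sum()
    return float(-(p * np.log2(p)).sum())

# near-uniform noise should land close to the 8-bit maximum of 8 bits
rng = np.random.default_rng(0)
print(shannon_entropy(rng.integers(0, 256, size=(64, 64))))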
From 087609ee181a91a523647435ffffa6288a317e2f Mon Sep 17 00:00:00 2001
From: captin411
Date: Wed, 19 Oct 2022 03:19:35 -0700
Subject: UI changes for focal point image cropping
---
modules/ui.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 1ff7eb4f..b6be713b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1234,6 +1234,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
process_split = gr.Checkbox(label='Split oversized images into two')
+ process_entropy_focus = gr.Checkbox(label='Create auto focal point crop')
process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
@@ -1318,7 +1319,8 @@ def create_ui(wrap_gradio_gpu_call):
process_flip,
process_split,
process_caption,
- process_caption_deepbooru
+ process_caption_deepbooru,
+ process_entropy_focus
],
outputs=[
ti_output,
--
cgit v1.2.3
From 019a3a88f07766f2d32c32fbe8e41625f28ecb5e Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 19 Oct 2022 17:15:47 +0100
Subject: Update ui.py
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index d2e24880..1573ef82 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1247,7 +1247,7 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Tab(label="Train"):
- gr.HTML(value="Train an embedding; must specify a directory with a set of 1:1 ratio images
")
+ gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork wiki
")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
--
cgit v1.2.3
From 2ce52d32e41fb523d1494f45073fd18496e52d35 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Wed, 19 Oct 2022 16:31:12 +0000
Subject: fix for #3086 failing to load any previous hypernet
---
modules/hypernetworks/hypernetwork.py | 60 ++++++++++++++++-------------------
1 file changed, 28 insertions(+), 32 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7d519cd9..74300122 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -24,11 +24,10 @@ class HypernetworkModule(torch.nn.Module):
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
super().__init__()
- if layer_structure is not None:
- assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
- assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
- else:
- layer_structure = parse_layer_structure(dim, state_dict)
+
+ assert layer_structure is not None, "layer_structure must not be None"
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
@@ -39,23 +38,30 @@ class HypernetworkModule(torch.nn.Module):
self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
- try:
- self.load_state_dict(state_dict)
- except RuntimeError:
- self.try_load_previous(state_dict)
+ self.fix_old_state_dict(state_dict)
+ self.load_state_dict(state_dict)
else:
for layer in self.linear:
- layer.weight.data.normal_(mean = 0.0, std = 0.01)
+ layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
self.to(devices.device)
- def try_load_previous(self, state_dict):
- states = self.state_dict()
- states['linear.0.bias'].copy_(state_dict['linear1.bias'])
- states['linear.0.weight'].copy_(state_dict['linear1.weight'])
- states['linear.1.bias'].copy_(state_dict['linear2.bias'])
- states['linear.1.weight'].copy_(state_dict['linear2.weight'])
+ def fix_old_state_dict(self, state_dict):
+ changes = {
+ 'linear1.bias': 'linear.0.bias',
+ 'linear1.weight': 'linear.0.weight',
+ 'linear2.bias': 'linear.1.bias',
+ 'linear2.weight': 'linear.1.weight',
+ }
+
+ for fr, to in changes.items():
+ x = state_dict.get(fr, None)
+ if x is None:
+ continue
+
+ del state_dict[fr]
+ state_dict[to] = x
def forward(self, x):
return x + self.linear(x) * self.multiplier
@@ -71,18 +77,6 @@ def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-def parse_layer_structure(dim, state_dict):
- i = 0
- layer_structure = [1]
-
- while (key := "linear.{}.weight".format(i)) in state_dict:
- weight = state_dict[key]
- layer_structure.append(len(weight) // dim)
- i += 1
-
- return layer_structure
-
-
class Hypernetwork:
filename = None
name = None
@@ -135,17 +129,18 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
+ self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
+
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
- HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
)
self.name = state_dict.get('name', self.name)
self.step = state_dict.get('step', 0)
- self.layer_structure = state_dict.get('layer_structure', None)
- self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
@@ -244,6 +239,7 @@ def stack_conds(conds):
return torch.stack(conds)
+
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected'
--
cgit v1.2.3
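In short, the fix above swaps the hand-written tensor copies for an in-place rename of legacy state-dict keys before load_state_dict. The renaming pattern in isolation, as a minimal sketch (hypothetical tensors rather than real hypernetwork weights):

import torch

# map old-format keys to the names the current module expects
changes = {
    'linear1.weight': 'linear.0.weight',
    'linear1.bias': 'linear.0.bias',
    'linear2.weight': 'linear.1.weight',
    'linear2.bias': 'linear.1.bias',
}

# hypothetical old-format checkpoint
state_dict = {'linear1.weight': torch.zeros(2, 2), 'linear1.bias': torch.zeros(2)}

for old, new in changes.items():
    if old in state_dict:
        state_dict[new] = state_dict.pop(old)

print(sorted(state_dict))  # ['linear.0.bias', 'linear.0.weight']

Keys absent from the old checkpoint are simply skipped, which is what lets both old and new formats load through the same path.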
From b748b583c0b9f771c1be509175a6913e3f2ad97c Mon Sep 17 00:00:00 2001
From: Mackerel
Date: Wed, 19 Oct 2022 14:22:03 -0400
Subject: generation_parameters_copypaste.py: fix indent
---
modules/generation_parameters_copypaste.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 98d24406..0f041449 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -45,8 +45,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else:
prompt += ("" if prompt == "" else "\n") + line
- res["Prompt"] = prompt
- res["Negative prompt"] = negative_prompt
+ res["Prompt"] = prompt
+ res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
m = re_imagesize.match(v)
--
cgit v1.2.3
From eb7ba4b713ac2fb960ecf6365b1de0c89451e583 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 19 Oct 2022 19:50:46 +0100
Subject: update training header text
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 1573ef82..93c0767c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1247,7 +1247,7 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Tab(label="Train"):
- gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork wiki
")
+ gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork [wiki]
")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
--
cgit v1.2.3
From 4d663055ded968831ec97f047dfa8e94036cf1c1 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 19 Oct 2022 20:33:18 +0100
Subject: update ui with extra training options
---
modules/ui.py | 11 +++++++++--
1 file changed, 9 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 93c0767c..cdb9d335 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1206,6 +1206,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
with gr.Row():
with gr.Column(scale=3):
@@ -1219,6 +1220,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
with gr.Row():
with gr.Column(scale=3):
@@ -1247,14 +1249,17 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Tab(label="Train"):
- gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork [wiki]
")
+ gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
- learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+ with gr.Row():
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
+
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
@@ -1288,6 +1293,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name,
initialization_text,
nvpt,
+ overwrite_old_embedding,
],
outputs=[
train_embedding_name,
@@ -1303,6 +1309,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes,
new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm,
+ overwrite_old_hypernetwork,
],
outputs=[
train_hypernetwork_name,
--
cgit v1.2.3
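
A condensed sketch of the new controls (Gradio; labels as in the diff, layout details omitted). Splitting the single learn_rate box lets each training mode keep its own sensible default instead of sharing one value:

    import gradio as gr

    with gr.Blocks() as demo:
        overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
        with gr.Row():
            embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', value="0.005")
            hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', value="0.00001")
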
From 41e3877be2c667316515c86037413763eb0ba4da Mon Sep 17 00:00:00 2001
From: captin411
Date: Wed, 19 Oct 2022 13:44:59 -0700
Subject: fix entropy point calculation
---
modules/textual_inversion/preprocess.py | 34 ++++++++++++++++++---------------
1 file changed, 19 insertions(+), 15 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 168bfb09..7c1a594e 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -196,9 +196,9 @@ def image_focal_points(im):
points = cv2.goodFeaturesToTrack(
np_im,
- maxCorners=50,
+ maxCorners=100,
qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.05,
+ minDistance=min(grayscale.width, grayscale.height)*0.07,
useHarrisDetector=False,
)
@@ -218,28 +218,32 @@ def image_focal_points(im):
def image_entropy_point(im, crop_width, crop_height):
- img = im.copy()
- # just make it easier to slide the test crop with images oriented the same way
- if (img.size[0] < img.size[1]):
- portrait = True
- img = img.rotate(90, expand=1)
+ landscape = im.height < im.width
+ portrait = im.height > im.width
+ if landscape:
+ move_idx = [0, 2]
+ move_max = im.size[0]
+ elif portrait:
+ move_idx = [1, 3]
+ move_max = im.size[1]
e_max = 0
crop_current = [0, 0, crop_width, crop_height]
crop_best = crop_current
- while crop_current[2] < img.size[0]:
- crop = img.crop(tuple(crop_current))
+ while crop_current[move_idx[1]] < move_max:
+ crop = im.crop(tuple(crop_current))
e = image_entropy(crop)
- if (e_max < e):
+ if (e > e_max):
e_max = e
crop_best = list(crop_current)
- crop_current[0] += 4
- crop_current[2] += 4
+ crop_current[move_idx[0]] += 4
+ crop_current[move_idx[1]] += 4
+
+ x_mid = int(crop_best[0] + crop_width/2)
+ y_mid = int(crop_best[1] + crop_height/2)
- x_mid = int((crop_best[2] - crop_best[0])/2)
- y_mid = int((crop_best[3] - crop_best[1])/2)
return {
'x': x_mid,
@@ -250,7 +254,7 @@ def image_entropy_point(im, crop_width, crop_height):
def image_entropy(im):
# greyscale image entropy
- band = np.asarray(im.convert("L"))
+ band = np.asarray(im.convert("1"))
hist, _ = np.histogram(band, bins=range(0, 256))
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
--
cgit v1.2.3
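
The substantive fix in this commit is the midpoint arithmetic: the old expression measured the crop's own half-size rather than its position, so the reported focal point came out the same wherever the best window actually landed. A quick check with hypothetical numbers:

    crop_width, crop_height = 512, 512
    crop_best = [128, 0, 640, 512]          # best window found 128 px from the left edge

    x_mid_old = int((crop_best[2] - crop_best[0]) / 2)   # 256 -- always half the width
    x_mid_new = int(crop_best[0] + crop_width / 2)       # 384 -- the window's real centre
    assert (x_mid_old, x_mid_new) == (256, 384)
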
From 8e7097d06a6a261580d34375c9d2a9e4ffc63ffa Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Wed, 19 Oct 2022 13:47:45 -0700
Subject: Added support for RunwayML inpainting model
---
modules/processing.py | 34 ++++++-
modules/sd_hijack_inpainting.py | 208 ++++++++++++++++++++++++++++++++++++++++
modules/sd_models.py | 16 +++-
modules/sd_samplers.py | 50 +++++++---
4 files changed, 293 insertions(+), 15 deletions(-)
create mode 100644 modules/sd_hijack_inpainting.py
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index bcb0c32c..a6c308f9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -546,7 +546,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device)
+ image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=image_conditioning)
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@@ -714,10 +723,31 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
+ if self.image_mask is not None:
+ conditioning_mask = np.array(self.image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ conditioning_mask = conditioning_mask.to(image.device)
+ conditioning_image = image * (1.0 - conditioning_mask)
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
+
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
new file mode 100644
index 00000000..7e5670d6
--- /dev/null
+++ b/modules/sd_hijack_inpainting.py
@@ -0,0 +1,208 @@
+import torch
+import numpy as np
+
+from tqdm import tqdm
+from einops import rearrange, repeat
+from omegaconf import ListConfig
+
+from types import MethodType
+
+import ldm.models.diffusion.ddpm
+import ldm.models.diffusion.ddim
+
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+
+# =================================================================================================
+# Monkey patch DDIMSampler methods from RunwayML repo directly.
+# Adapted from:
+# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
+# =================================================================================================
+@torch.no_grad()
+def sample(
+ self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list):
+ ctmp = elf.inpainting_fill == 2:
+ self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
+ elif self.inpainting_fill == 3:
+ self.init_latent = self.init_latent * self.mask
+
+ if self.image_mask is not None:
+ conditioning_mask = np.array(self.image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ conditioning_mask = conditioning_mask.to(image.device)
+ conditioning_image = image * (1.0 - conditioning_mask)
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
+
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ x = create_random_tensors([opctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+
+@torch.no_grad()
+def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
+ for i in range(len(c[k]))
+ ]
+ else:
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+
+# =================================================================================================
+# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
+# Adapted from:
+# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
+# =================================================================================================
+
+@torch.no_grad()
+def get_unconditional_conditioning(self, batch_size, null_label=None):
+ if null_label is not None:
+ xc = null_label
+ if isinstance(xc, ListConfig):
+ xc = list(xc)
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ if hasattr(xc, "to"):
+ xc = xc.to(self.device)
+ c = self.get_learned_conditioning(xc)
+ else:
+ # todo: get null label from cond_stage_model
+ raise NotImplementedError()
+ c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+ return c
+
+class LatentInpaintDiffusion(LatentDiffusion):
+ def __init__(
+ self,
+ concat_keys=("mask", "masked_image"),
+ masked_image_key="masked_image",
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.masked_image_key = masked_image_key
+ assert self.masked_image_key in concat_keys
+ self.concat_keys = concat_keys
+
+def should_hijack_inpainting(checkpoint_info):
+ return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
+
+def do_inpainting_hijack():
+ ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
+ ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
+ ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
+ ldm.models.diffusion.ddim.DDIMSampler.sample = sample
\ No newline at end of file
diff --git a/modules/sd_models.py b/modules/sd_models.py
index eae22e87..47836d25 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices
from modules.paths import models_path
+from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -211,6 +212,19 @@ def load_model():
print(f"Loading config from: {checkpoint_info.config}")
sd_config = OmegaConf.load(checkpoint_info.config)
+
+ if should_hijack_inpainting(checkpoint_info):
+ do_inpainting_hijack()
+
+ # Hardcoded config for now...
+ sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+ sd_config.model.params.use_ema = False
+ sd_config.model.params.conditioning_key = "hybrid"
+ sd_config.model.params.unet_config.params.in_channels = 9
+
+ # Create a "fake" config with a different name so that we know to unload it when switching models.
+ checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
@@ -234,7 +248,7 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
- if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
+ if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
shared.sd_model = load_model()
return shared.sd_model
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index b58e810b..9d3cf289 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -136,9 +136,15 @@ class VanillaStableDiffusionSampler:
if self.stop_at is not None and self.step > self.stop_at:
raise InterruptedException
+ # Have to unwrap the inpainting conditioning here to perform pre-processing
+ image_conditioning = None
+ if isinstance(cond, dict):
+ image_conditioning = cond["c_concat"][0]
+ cond = cond["c_crossattn"][0]
+ unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
@@ -157,6 +163,10 @@ class VanillaStableDiffusionSampler:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
+ if image_conditioning is not None:
+ cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None:
@@ -182,7 +192,7 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
self.initialize(p)
@@ -202,7 +212,7 @@ class VanillaStableDiffusionSampler:
return samples
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
self.initialize(p)
self.init_latent = None
@@ -210,6 +220,11 @@ class VanillaStableDiffusionSampler:
steps = steps or p.steps
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ if image_conditioning is not None:
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
# existing code fails with certain step counts, like 9
try:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
@@ -228,7 +243,7 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
- def forward(self, x, sigma, uncond, cond, cond_scale):
+ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise InterruptedException
@@ -239,28 +254,29 @@ class CFGDenoiser(torch.nn.Module):
repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+ x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
@@ -361,7 +377,7 @@ class KDiffusionSampler:
return extra_params_kwargs
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
if p.sampler_noise_scheduler_override:
@@ -389,11 +405,16 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps
if p.sampler_noise_scheduler_override:
@@ -414,7 +435,12 @@ class KDiffusionSampler:
else:
extra_params_kwargs['sigmas'] = sigmas
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
--
cgit v1.2.3
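
The shape bookkeeping behind these changes, as a standalone sketch (helper name hypothetical): the inpainting UNet is configured with in_channels = 9, i.e. the 4 noised-latent channels plus a 5-channel c_concat block made of 1 rounded-mask channel and 4 masked-image latent channels, while the text conditioning rides along under c_crossattn.

    import torch

    def wrap_for_inpainting(text_cond, image_conditioning):
        # Hybrid conditioning dict as threaded through the samplers above.
        return {"c_concat": [image_conditioning], "c_crossattn": [text_cond]}

    # txt2img has no real mask, so a fully-masked image is faked: a zero latent
    # prefixed (via pad) with an all-ones mask channel, giving 5 channels.
    latent = torch.zeros(2, 4, 64, 64)        # hypothetical batch of VAE latents
    image_conditioning = torch.nn.functional.pad(latent, (0, 0, 0, 0, 1, 0), value=1.0)
    assert image_conditioning.shape[1] == 5   # 4 latent + 1 mask; 4 noise channels make 9
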
From 0719c10bf1b817364a498ee11b90d30d3d527344 Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Wed, 19 Oct 2022 13:56:26 -0700
Subject: Fixed copying mistake
---
modules/sd_hijack_inpainting.py | 79 +++++++++++++----------------------------
1 file changed, 25 insertions(+), 54 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 7e5670d6..d4d28d2e 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -19,63 +19,35 @@ from ldm.models.diffusion.ddim import DDIMSampler, noise_like
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
# =================================================================================================
@torch.no_grad()
-def sample(
- self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
+def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
- ctmp = elf.inpainting_fill == 2:
- self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
- elif self.inpainting_fill == 3:
- self.init_latent = self.init_latent * self.mask
-
- if self.image_mask is not None:
- conditioning_mask = np.array(self.image_mask.convert("L"))
- conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
- conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
-
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
- else:
- conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
-
- # Create another latent image, this time with a masked version of the original input.
- conditioning_mask = conditioning_mask.to(image.device)
- conditioning_image = image * (1.0 - conditioning_mask)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
-
- # Create the concatenated conditioning tensor to be fed to `c_concat`
- conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
- conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
- self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
- self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
- x = create_random_tensors([opctmp[0]
+ ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
@@ -106,7 +78,6 @@ def sample(
)
return samples, intermediates
-
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
--
cgit v1.2.3
From dde9f960727bfe151d418e43685a2881cf580a17 Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Wed, 19 Oct 2022 14:14:24 -0700
Subject: added support for ddim img2img
---
modules/sd_samplers.py | 6 ++++++
1 file changed, 6 insertions(+)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 9d3cf289..d270e4df 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -208,6 +208,12 @@ class VanillaStableDiffusionSampler:
self.init_latent = x
self.step = 0
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ if image_conditioning is not None:
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
--
cgit v1.2.3
From c418467c03db916c3e5312e6ac4a67365e196dbd Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Wed, 19 Oct 2022 15:09:43 -0700
Subject: Don't compute latent mask if we're not using it. Also added support
for fixed highres_fix generation.
---
modules/processing.py | 72 +++++++++++++++++++++++++++++++-------------------
modules/sd_samplers.py | 4 +++
2 files changed, 49 insertions(+), 27 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index a6c308f9..684e5833 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -541,12 +541,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
- self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
-
- if not self.enable_hr:
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
-
+ def create_dummy_mask(self, x):
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
@@ -555,11 +551,23 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=image_conditioning)
+ else:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ image_conditioning = torch.zeros(x.shape[0], 5, x.shape[-2], x.shape[-1], dtype=x.dtype, device=x.device)
+
+ return image_conditioning
+
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+
+ if not self.enable_hr:
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
@@ -596,7 +604,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
return samples
@@ -723,26 +731,36 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- if self.image_mask is not None:
- conditioning_mask = np.array(self.image_mask.convert("L"))
- conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
- conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+ conditioning_key = self.sampler.conditioning_key
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
+ if conditioning_key in {'hybrid', 'concat'}:
+ if self.image_mask is not None:
+ conditioning_mask = np.array(self.image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ conditioning_mask = conditioning_mask.to(image.device)
+ conditioning_image = image * (1.0 - conditioning_mask)
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
else:
- conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
-
- # Create another latent image, this time with a masked version of the original input.
- conditioning_mask = conditioning_mask.to(image.device)
- conditioning_image = image * (1.0 - conditioning_mask)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
-
- # Create the concatenated conditioning tensor to be fed to `c_concat`
- conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
- conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
- self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
- self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
+ self.image_conditioning = torch.zeros(
+ self.init_latent.shape[0], 5, self.init_latent.shape[-2], self.init_latent.shape[-1],
+ dtype=self.init_latent.dtype,
+ device=self.init_latent.device
+ )
+
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index d270e4df..c21be26e 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler:
self.config = None
self.last_latent = None
+ self.conditioning_key = sd_model.model.conditioning_key
+
def number_of_needed_noises(self, p):
return 0
@@ -328,6 +330,8 @@ class KDiffusionSampler:
self.config = None
self.last_latent = None
+ self.conditioning_key = sd_model.model.conditioning_key
+
def callback_state(self, d):
step = d['i']
latent = d["denoised"]
--
cgit v1.2.3
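
Why the non-inpainting branch uses exactly 5 channels (a sketch with hypothetical sizes): c_concat for the inpainting model is 1 mask channel + 4 latent channels, so a zero tensor of the same shape lets every sampler accept image conditioning unconditionally without paying for a VAE encoder call on ordinary models.

    import torch

    mask_channels, latent_channels = 1, 4
    x = torch.zeros(1, 4, 64, 64)             # hypothetical denoising latent

    # Cheap stand-in: same shape c_concat would have, no encoder pass needed.
    dummy = torch.zeros(x.shape[0], mask_channels + latent_channels,
                        x.shape[-2], x.shape[-1], dtype=x.dtype, device=x.device)
    assert dummy.shape[1] == 5
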
From d6ea5841374a28f3f6deb73abc251c8f0bcb240f Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:07:57 +0100
Subject: change html output
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7d519cd9..73c1cb80 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -380,7 +380,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
Loss: {mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
-Last saved embedding: {html.escape(last_saved_file)}<br/>
+Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
"""
--
cgit v1.2.3
From 166be3919b817cee5e702fd01c34afe9081b952c Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:09:40 +0100
Subject: allow overwrite old hn
---
modules/hypernetworks/ui.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 08f75f15..f45345ea 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -10,9 +10,10 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure)))
--
cgit v1.2.3
From 0087079c2d487b67b06ffc30f36ce486a74e6318 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:10:59 +0100
Subject: allow overwrite old embedding
---
modules/textual_inversion/textual_inversion.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 3be69562..5776778b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -153,7 +153,7 @@ class EmbeddingDatabase:
return None, None
-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
@@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
--
cgit v1.2.3
From 632e8d660293081cadb145d8062e5aff0a4a8f0d Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:19:40 +0100
Subject: split learn rates
---
modules/ui.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index cdb9d335..d07184ee 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1342,7 +1342,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
- learn_rate,
+ embedding_learn_rate,
batch_size,
dataset_directory,
log_directory,
@@ -1367,7 +1367,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
- learn_rate,
+ hypernetwork_learn_rate,
batch_size,
dataset_directory,
log_directory,
--
cgit v1.2.3
From c3835ec85cbb44fa3c46fa871c622b6fee235c89 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:24:24 +0100
Subject: pass overwrite old flag
---
modules/textual_inversion/ui.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index 36881e7a..e712284d 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
-def create_embedding(name, initialization_text, nvpt):
- filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
--
cgit v1.2.3
From 4d6b9f76a55fd0ac0f72634071032dd9c6efb409 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:27:16 +0100
Subject: reorder create_hypernetwork params
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index d07184ee..322c082b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1307,9 +1307,9 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
new_hypernetwork_name,
new_hypernetwork_sizes,
+ overwrite_old_hypernetwork,
new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm,
- overwrite_old_hypernetwork,
],
outputs=[
train_hypernetwork_name,
--
cgit v1.2.3
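
The reorder is needed because Gradio binds the inputs list to the target function positionally: once overwrite_old became the third parameter of create_hypernetwork (see the earlier "allow overwrite old hn" commit), the checkbox had to sit third in the list. Schematically:

    def create_hypernetwork(name, enable_sizes, overwrite_old,
                            layer_structure=None, add_layer_norm=False):
        # inputs=[new_hypernetwork_name, new_hypernetwork_sizes,
        #         overwrite_old_hypernetwork, ...] must match this order.
        ...
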
From fbcce66601994f6ed370db36d9c238840fed6bd2 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:46:54 +0100
Subject: add existing caption file handling
---
modules/textual_inversion/preprocess.py | 32 ++++++++++++++++++++++++--------
1 file changed, 24 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 886cf0c3..5c43fe13 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -48,7 +48,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
- def save_pic_with_caption(image, index):
+ def save_pic_with_caption(image, index, existing_caption=None):
caption = ""
if process_caption:
@@ -66,17 +66,26 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
basename = f"{index:05}-{subindex[0]}-{filename_part}"
image.save(os.path.join(dst, f"{basename}.png"))
+ if preprocess_txt_action == 'prepend' and existing_caption:
+ caption = existing_caption + ' ' + caption
+ elif preprocess_txt_action == 'append' and existing_caption:
+ caption = caption + ' ' + existing_caption
+ elif preprocess_txt_action == 'copy' and existing_caption:
+ caption = existing_caption
+
+ caption = caption.strip()
+
if len(caption) > 0:
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
subindex[0] += 1
- def save_pic(image, index):
+ def save_pic(image, index, existing_caption=None):
save_pic_with_caption(image, index)
if process_flip:
- save_pic_with_caption(ImageOps.mirror(image), index)
+ save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
@@ -86,6 +95,13 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
except Exception:
continue
+ existing_caption = None
+
+ try:
+ existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read()
+ except Exception as e:
+ print(e)
+
if shared.state.interrupted:
break
@@ -97,20 +113,20 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
img = img.resize((width, height * img.height // img.width))
top = img.crop((0, 0, width, height))
- save_pic(top, index)
+ save_pic(top, index, existing_caption=existing_caption)
bot = img.crop((0, img.height - height, width, img.height))
- save_pic(bot, index)
+ save_pic(bot, index, existing_caption=existing_caption)
elif process_split and is_wide:
img = img.resize((width * img.width // img.height, height))
left = img.crop((0, 0, width, height))
- save_pic(left, index)
+ save_pic(left, index, existing_caption=existing_caption)
right = img.crop((img.width - width, 0, img.width, height))
- save_pic(right, index)
+ save_pic(right, index, existing_caption=existing_caption)
else:
img = images.resize_image(1, img, width, height)
- save_pic(img, index)
+ save_pic(img, index, existing_caption=existing_caption)
shared.state.nextjob()
--
cgit v1.2.3
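
The merge rules added to save_pic_with_caption, extracted as a standalone sketch (helper name hypothetical); 'ignore', or the absence of an existing .txt file, keeps the generated caption unchanged:

    def merge_caption(generated, existing, action):
        if existing:
            if action == 'prepend':
                generated = existing + ' ' + generated
            elif action == 'append':
                generated = generated + ' ' + existing
            elif action == 'copy':
                generated = existing
        return generated.strip()

    assert merge_caption("a cat", "tabby", "prepend") == "tabby a cat"
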
From ab353b141df8eee042b0964bcb645015dabf3459 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:48:07 +0100
Subject: link existing txt option
---
modules/ui.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 322c082b..7f52ac0c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1234,6 +1234,7 @@ def create_ui(wrap_gradio_gpu_call):
process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', choices=['ignore', 'copy', 'prepend', 'append'])
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
@@ -1326,6 +1327,7 @@ def create_ui(wrap_gradio_gpu_call):
process_dst,
process_width,
process_height,
+ preprocess_txt_action,
process_flip,
process_split,
process_caption,
--
cgit v1.2.3
From 9b65c4ecf4f8eb6187ee721918adebe68e9bc631 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:49:23 +0100
Subject: pass preprocess_txt_action param
---
modules/textual_inversion/preprocess.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 5c43fe13..3713bc89 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -11,7 +11,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
try:
if process_caption:
shared.interrogator.load()
@@ -21,7 +21,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
- preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru)
finally:
@@ -33,7 +33,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
--
cgit v1.2.3
From 55d8c6cce6d3aef848b9f194adad2ce53064d8b7 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:53:29 +0100
Subject: default to ignore existing captions
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 7f52ac0c..bd5f1b05 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1234,7 +1234,7 @@ def create_ui(wrap_gradio_gpu_call):
process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
- preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', choices=['ignore', 'copy', 'prepend', 'append'])
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
--
cgit v1.2.3
From 8b74b9aa9a20e4c5c1f72641f8b9617479eb276b Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Wed, 19 Oct 2022 19:06:14 -0500
Subject: add symbol for clear button and simplify roll_col css selector
---
modules/ui.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index a2dbd41e..9f6edc5f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -83,6 +83,7 @@ folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
+trash_prompt_symbol = '\U0001F5D1' # 🗑
def plaintext_to_html(text):
@@ -498,6 +499,7 @@ def create_toprow(is_img2img):
paste = gr.Button(value=paste_symbol, elem_id="paste")
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
+ trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt")
token_counter = gr.HTML(value=" ", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
--
cgit v1.2.3
From 6f98e89486f55b0e4657e96ce640cf1c4675d187 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Thu, 20 Oct 2022 00:10:45 +0000
Subject: update
---
modules/hypernetworks/hypernetwork.py | 29 +++++++++++++++--------
modules/hypernetworks/ui.py | 3 ++-
modules/ui.py | 43 +++++++++++++++++++----------------
3 files changed, 44 insertions(+), 31 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 74300122..7d617680 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -22,16 +22,20 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
- def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
+ def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
super().__init__()
- assert layer_structure is not None, "layer_structure mut not be None"
+ assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+ if activation_func == "relu":
+ linears.append(torch.nn.ReLU())
+ if activation_func == "leakyrelu":
+ linears.append(torch.nn.LeakyReLU())
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -42,8 +46,9 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
- layer.weight.data.normal_(mean=0.0, std=0.01)
- layer.bias.data.zero_()
+ if not "ReLU" in layer.__str__():
+ layer.weight.data.normal_(mean=0.0, std=0.01)
+ layer.bias.data.zero_()
self.to(devices.device)
@@ -69,7 +74,8 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
- layer_structure += [layer.weight, layer.bias]
+ if not "ReLU" in layer.__str__():
+ layer_structure += [layer.weight, layer.bias]
return layer_structure
@@ -81,7 +87,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
self.filename = None
self.name = name
self.layers = {}
@@ -90,11 +96,12 @@ class Hypernetwork:
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
self.add_layer_norm = add_layer_norm
+ self.activation_func = activation_func
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
+ HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
)
def weights(self):
@@ -117,6 +124,7 @@ class Hypernetwork:
state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure
state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['activation_func'] = self.activation_func
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -131,12 +139,13 @@ class Hypernetwork:
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ self.activation_func = state_dict.get('activation_func', None)
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
- HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
)
self.name = state_dict.get('name', self.name)
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 08f75f15..83f9547b 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -10,7 +10,7 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
+def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
@@ -22,6 +22,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
add_layer_norm=add_layer_norm,
+ activation_func=activation_func,
)
hypernet.save(fn)
diff --git a/modules/ui.py b/modules/ui.py
index d2e24880..8751fa9c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -5,43 +5,44 @@ import json
import math
import mimetypes
import os
+import platform
import random
+import subprocess as sp
import sys
import tempfile
import time
import traceback
-import platform
-import subprocess as sp
from functools import partial, reduce
+import gradio as gr
+import gradio.routes
+import gradio.utils
import numpy as np
+import piexif
import torch
from PIL import Image, PngImagePlugin
-import piexif
-import gradio as gr
-import gradio.utils
-import gradio.routes
-
-from modules import sd_hijack, sd_models, localization
+from modules import localization, sd_hijack, sd_models
from modules.paths import script_path
-from modules.shared import opts, cmd_opts, restricted_opts
+from modules.shared import cmd_opts, opts, restricted_opts
+
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
-import modules.shared as shared
-from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.sd_hijack import model_hijack
+
+import modules.codeformer_model
+import modules.generation_parameters_copypaste
+import modules.gfpgan_model
+import modules.hypernetworks.ui
+import modules.images_history as img_his
import modules.ldsr_model
import modules.scripts
-import modules.gfpgan_model
-import modules.codeformer_model
+import modules.shared as shared
import modules.styles
-import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
from modules import prompt_parser
from modules.images import save_image
-import modules.textual_inversion.ui
-import modules.hypernetworks.ui
-import modules.images_history as img_his
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -268,8 +269,8 @@ def calc_time_left(progress, threshold, label, force_display):
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
- if (eta_relative > threshold and progress > 0.02) or force_display:
- return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+ if (eta_relative > threshold and progress > 0.02) or force_display:
+ return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
else:
return ""
@@ -1219,6 +1220,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["relu", "leakyrelu"])
with gr.Row():
with gr.Column(scale=3):
@@ -1303,6 +1305,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes,
new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm,
+ new_hypernetwork_activation_func,
],
outputs=[
train_hypernetwork_name,
--
cgit v1.2.3
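The heart of the commit above is the construction loop: one Linear per adjacent pair in the multiplier sequence, with an optional activation appended after each. A standalone sketch of that pattern, assuming only PyTorch (build_layers is an illustrative name, not the patch's):

    import torch

    def build_layers(dim, layer_structure=(1, 2, 1), activation_func=None):
        linears = []
        for i in range(len(layer_structure) - 1):
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]),
                                           int(dim * layer_structure[i + 1])))
            if activation_func == "relu":
                linears.append(torch.nn.ReLU())
            elif activation_func == "leakyrelu":
                linears.append(torch.nn.LeakyReLU())
        return torch.nn.Sequential(*linears)

    # (1, 2, 1) with dim=320: Linear(320,640) -> ReLU -> Linear(640,320) -> ReLU
    net = build_layers(320, (1, 2, 1), "relu")

Note that, as in the patch, the activation is appended after every Linear, including the last one.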
From ba469343e6a1c6e23e82acf5feb65c6101dacbb2 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Thu, 20 Oct 2022 00:17:04 +0000
Subject: align ui.py imports with upstream
---
modules/ui.py | 37 ++++++++++++++++++-------------------
1 file changed, 18 insertions(+), 19 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 987b1d7d..913b23b4 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -5,44 +5,43 @@ import json
import math
import mimetypes
import os
-import platform
import random
-import subprocess as sp
import sys
import tempfile
import time
import traceback
+import platform
+import subprocess as sp
from functools import partial, reduce
-import gradio as gr
-import gradio.routes
-import gradio.utils
import numpy as np
-import piexif
import torch
from PIL import Image, PngImagePlugin
+import piexif
-from modules import localization, sd_hijack, sd_models
-from modules.paths import script_path
-from modules.shared import cmd_opts, opts, restricted_opts
+import gradio as gr
+import gradio.utils
+import gradio.routes
+from modules import sd_hijack, sd_models, localization
+from modules.paths import script_path
+from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
-
-import modules.codeformer_model
-import modules.generation_parameters_copypaste
-import modules.gfpgan_model
-import modules.hypernetworks.ui
-import modules.images_history as img_his
+import modules.shared as shared
+from modules.sd_samplers import samplers, samplers_for_img2img
+from modules.sd_hijack import model_hijack
import modules.ldsr_model
import modules.scripts
-import modules.shared as shared
+import modules.gfpgan_model
+import modules.codeformer_model
import modules.styles
-import modules.textual_inversion.ui
+import modules.generation_parameters_copypaste
from modules import prompt_parser
from modules.images import save_image
-from modules.sd_hijack import model_hijack
-from modules.sd_samplers import samplers, samplers_for_img2img
+import modules.textual_inversion.ui
+import modules.hypernetworks.ui
+import modules.images_history as img_his
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
--
cgit v1.2.3
From 59ed74438318af893d2cba552b0e28dbc2a9266c Mon Sep 17 00:00:00 2001
From: captin411
Date: Wed, 19 Oct 2022 17:19:02 -0700
Subject: face detection algo, configurability, reusability
Try to move the crop in the direction of a face if it is present
More internal configuration options for choosing weights of each of the algorithm's findings
Move logic into its module
---
modules/textual_inversion/autocrop.py | 216 ++++++++++++++++++++++++++++++++
modules/textual_inversion/preprocess.py | 150 +++-------------------
2 files changed, 230 insertions(+), 136 deletions(-)
create mode 100644 modules/textual_inversion/autocrop.py
(limited to 'modules')
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
new file mode 100644
index 00000000..f858a958
--- /dev/null
+++ b/modules/textual_inversion/autocrop.py
@@ -0,0 +1,216 @@
+import cv2
+from collections import defaultdict
+from math import log, sqrt
+import numpy as np
+from PIL import Image, ImageDraw
+
+GREEN = "#0F0"
+BLUE = "#00F"
+RED = "#F00"
+
+def crop_image(im, settings):
+ """ Intelligently crop an image to the subject matter """
+ if im.height > im.width:
+ im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width))
+ else:
+ im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height))
+
+ focus = focal_point(im, settings)
+
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(settings.crop_height / 2)
+ x_half = int(settings.crop_width / 2)
+
+ x1 = focus.x - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + settings.crop_width > im.width:
+ x1 = im.width - settings.crop_width
+
+ y1 = focus.y - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + settings.crop_height > im.height:
+ y1 = im.height - settings.crop_height
+
+ x2 = x1 + settings.crop_width
+ y2 = y1 + settings.crop_height
+
+ crop = [x1, y1, x2, y2]
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+ rect = list(crop)
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ if settings.destop_view_image:
+ im.show()
+
+ return im.crop(tuple(crop))
+
+def focal_point(im, settings):
+ corner_points = image_corner_points(im, settings)
+ entropy_points = image_entropy_points(im, settings)
+ face_points = image_face_points(im, settings)
+
+ total_points = len(corner_points) + len(entropy_points) + len(face_points)
+
+ corner_weight = settings.corner_points_weight
+ entropy_weight = settings.entropy_points_weight
+ face_weight = settings.face_points_weight
+
+ weight_pref_total = corner_weight + entropy_weight + face_weight
+
+ # weight things
+ pois = []
+ if weight_pref_total == 0 or total_points == 0:
+ return pois
+
+ pois.extend(
+ [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ]
+ )
+ pois.extend(
+ [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ]
+ )
+ pois.extend(
+ [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ]
+ )
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+
+ average_point = poi_average(pois, settings, im=im)
+
+ if settings.annotate_image:
+ d.ellipse([average_point.x - 25, average_point.y - 25, average_point.x + 25, average_point.y + 25], outline=GREEN)
+
+ return average_point
+
+
+def image_face_points(im, settings):
+ np_im = np.array(im)
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
+ classifier = cv2.CascadeClassifier(f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml')
+
+ minsize = int(min(im.width, im.height) * 0.15) # at least N percent of the smallest side
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.05,
+ minNeighbors=5, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+
+ if len(faces) == 0:
+ return []
+
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ if settings.annotate_image:
+ for f in rects:
+ d = ImageDraw.Draw(im)
+ d.rectangle(f, outline=RED)
+
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2) for r in rects]
+
+
+def image_corner_points(im, settings):
+ grayscale = im.convert("L")
+
+ # naive attempt at preventing focal points from collecting at watermarks near the bottom
+ gd = ImageDraw.Draw(grayscale)
+ gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
+
+ np_im = np.array(grayscale)
+
+ points = cv2.goodFeaturesToTrack(
+ np_im,
+ maxCorners=100,
+ qualityLevel=0.04,
+ minDistance=min(grayscale.width, grayscale.height)*0.07,
+ useHarrisDetector=False,
+ )
+
+ if points is None:
+ return []
+
+ focal_points = []
+ for point in points:
+ x, y = point.ravel()
+ focal_points.append(PointOfInterest(x, y))
+
+ return focal_points
+
+
+def image_entropy_points(im, settings):
+ landscape = im.height < im.width
+ portrait = im.height > im.width
+ if landscape:
+ move_idx = [0, 2]
+ move_max = im.size[0]
+ elif portrait:
+ move_idx = [1, 3]
+ move_max = im.size[1]
+ else:
+ return []
+
+ e_max = 0
+ crop_current = [0, 0, settings.crop_width, settings.crop_height]
+ crop_best = crop_current
+ while crop_current[move_idx[1]] < move_max:
+ crop = im.crop(tuple(crop_current))
+ e = image_entropy(crop)
+
+ if (e > e_max):
+ e_max = e
+ crop_best = list(crop_current)
+
+ crop_current[move_idx[0]] += 4
+ crop_current[move_idx[1]] += 4
+
+ x_mid = int(crop_best[0] + settings.crop_width/2)
+ y_mid = int(crop_best[1] + settings.crop_height/2)
+
+ return [PointOfInterest(x_mid, y_mid)]
+
+
+def image_entropy(im):
+ # greyscale image entropy
+ band = np.asarray(im.convert("1"))
+ hist, _ = np.histogram(band, bins=range(0, 256))
+ hist = hist[hist > 0]
+ return -np.log2(hist / hist.sum()).sum()
+
+
+def poi_average(pois, settings, im=None):
+ weight = 0.0
+ x = 0.0
+ y = 0.0
+ for pois in pois:
+ if settings.annotate_image and im is not None:
+ w = 4 * 0.5 * sqrt(pois.weight)
+ d = ImageDraw.Draw(im)
+ d.ellipse([
+ pois.x - w, pois.y - w,
+ pois.x + w, pois.y + w ], fill=BLUE)
+ weight += pois.weight
+ x += pois.x * pois.weight
+ y += pois.y * pois.weight
+ avg_x = round(x / weight)
+ avg_y = round(y / weight)
+
+ return PointOfInterest(avg_x, avg_y)
+
+
+class PointOfInterest:
+ def __init__(self, x, y, weight=1.0):
+ self.x = x
+ self.y = y
+ self.weight = weight
+
+
+class Settings:
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False):
+ self.crop_width = crop_width
+ self.crop_height = crop_height
+ self.corner_points_weight = corner_points_weight
+ self.entropy_points_weight = entropy_points_weight
+ self.face_points_weight = face_points_weight
+ self.annotate_image = annotate_image
+ self.destop_view_image = False
\ No newline at end of file
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 7c1a594e..0c79f012 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,7 +1,5 @@
import os
-import cv2
-import numpy as np
-from PIL import Image, ImageOps, ImageDraw
+from PIL import Image, ImageOps
import platform
import sys
import tqdm
@@ -9,6 +7,7 @@ import time
from modules import shared, images
from modules.shared import opts, cmd_opts
+from modules.textual_inversion import autocrop
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
@@ -80,6 +79,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index)
+
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
@@ -118,37 +118,16 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
processing_option_ran = True
- if process_entropy_focus and (is_tall or is_wide):
- if is_tall:
- img = img.resize((width, height * img.height // img.width))
- else:
- img = img.resize((width * img.width // img.height, height))
-
- x_focal_center, y_focal_center = image_central_focal_point(img, width, height)
-
- # take the focal point and turn it into crop coordinates that try to center over the focal
- # point but then get adjusted back into the frame
- y_half = int(height / 2)
- x_half = int(width / 2)
-
- x1 = x_focal_center - x_half
- if x1 < 0:
- x1 = 0
- elif x1 + width > img.width:
- x1 = img.width - width
-
- y1 = y_focal_center - y_half
- if y1 < 0:
- y1 = 0
- elif y1 + height > img.height:
- y1 = img.height - height
-
- x2 = x1 + width
- y2 = y1 + height
-
- crop = [x1, y1, x2, y2]
-
- focal = img.crop(tuple(crop))
+ if process_entropy_focus and img.height != img.width:
+ autocrop_settings = autocrop.Settings(
+ crop_width = width,
+ crop_height = height,
+ face_points_weight = 0.9,
+ entropy_points_weight = 0.7,
+ corner_points_weight = 0.5,
+ annotate_image = False
+ )
+ focal = autocrop.crop_image(img, autocrop_settings)
save_pic(focal, index)
processing_option_ran = True
@@ -157,105 +136,4 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
img = images.resize_image(1, img, width, height)
save_pic(img, index)
- shared.state.nextjob()
-
-
-def image_central_focal_point(im, target_width, target_height):
- focal_points = []
-
- focal_points.extend(
- image_focal_points(im)
- )
-
- fp_entropy = image_entropy_point(im, target_width, target_height)
- fp_entropy['weight'] = len(focal_points) + 1 # about half of the weight to entropy
-
- focal_points.append(fp_entropy)
-
- weight = 0.0
- x = 0.0
- y = 0.0
- for focal_point in focal_points:
- weight += focal_point['weight']
- x += focal_point['x'] * focal_point['weight']
- y += focal_point['y'] * focal_point['weight']
- avg_x = round(x // weight)
- avg_y = round(y // weight)
-
- return avg_x, avg_y
-
-
-def image_focal_points(im):
- grayscale = im.convert("L")
-
- # naive attempt at preventing focal points from collecting at watermarks near the bottom
- gd = ImageDraw.Draw(grayscale)
- gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
-
- np_im = np.array(grayscale)
-
- points = cv2.goodFeaturesToTrack(
- np_im,
- maxCorners=100,
- qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.07,
- useHarrisDetector=False,
- )
-
- if points is None:
- return []
-
- focal_points = []
- for point in points:
- x, y = point.ravel()
- focal_points.append({
- 'x': x,
- 'y': y,
- 'weight': 1.0
- })
-
- return focal_points
-
-
-def image_entropy_point(im, crop_width, crop_height):
- landscape = im.height < im.width
- portrait = im.height > im.width
- if landscape:
- move_idx = [0, 2]
- move_max = im.size[0]
- elif portrait:
- move_idx = [1, 3]
- move_max = im.size[1]
-
- e_max = 0
- crop_current = [0, 0, crop_width, crop_height]
- crop_best = crop_current
- while crop_current[move_idx[1]] < move_max:
- crop = im.crop(tuple(crop_current))
- e = image_entropy(crop)
-
- if (e > e_max):
- e_max = e
- crop_best = list(crop_current)
-
- crop_current[move_idx[0]] += 4
- crop_current[move_idx[1]] += 4
-
- x_mid = int(crop_best[0] + crop_width/2)
- y_mid = int(crop_best[1] + crop_height/2)
-
-
- return {
- 'x': x_mid,
- 'y': y_mid,
- 'weight': 1.0
- }
-
-
-def image_entropy(im):
- # greyscale image entropy
- band = np.asarray(im.convert("1"))
- hist, _ = np.histogram(band, bins=range(0, 256))
- hist = hist[hist > 0]
- return -np.log2(hist / hist.sum()).sum()
-
+ shared.state.nextjob()
\ No newline at end of file
--
cgit v1.2.3
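Two primitives carry most of autocrop.py above: a 1-bit entropy score for the sliding-window search, and a weighted average that fuses corner, entropy, and face points into a single focal point. A compact, self-contained sketch of both (weighted_center is a simplified stand-in for poi_average):

    import numpy as np
    from PIL import Image

    def image_entropy(im):
        # same formula as the module: entropy of the 1-bit histogram
        band = np.asarray(im.convert("1"), dtype=np.uint8)
        hist, _ = np.histogram(band, bins=range(0, 256))
        hist = hist[hist > 0]
        return -np.log2(hist / hist.sum()).sum()

    def weighted_center(points):
        # points: iterable of (x, y, weight) tuples
        total = sum(w for _, _, w in points)
        x = sum(px * w for px, _, w in points) / total
        y = sum(py * w for _, py, w in points) / total
        return round(x), round(y)

    print(image_entropy(Image.new("L", (64, 64), 255)))    # 0.0 for a uniform image
    print(weighted_center([(0, 0, 1.0), (100, 50, 3.0)]))  # (75, 38)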
From 858462f719c22ca9f24b94a41699653c34b5f4fb Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 02:57:18 +0100
Subject: do caption copy for both flips
---
modules/textual_inversion/preprocess.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 3713bc89..6bba3852 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -82,7 +82,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
subindex[0] += 1
def save_pic(image, index, existing_caption=None):
- save_pic_with_caption(image, index)
+ save_pic_with_caption(image, index, existing_caption=existing_caption)
if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
--
cgit v1.2.3
From c6345bd445463b7aa41723d6637e80dfa293a890 Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Wed, 19 Oct 2022 21:23:57 -0500
Subject: nerf line length
---
modules/ui.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 9f6edc5f..cb9a6c6e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -83,7 +83,7 @@ folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
-trash_prompt_symbol = '\U0001F5D1' # 🗑🗑🗑
+trash_prompt_symbol = '\U0001F5D1' #
def plaintext_to_html(text):
@@ -617,7 +617,10 @@ def create_ui(wrap_gradio_gpu_call):
return refresh_button
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\
+ txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\
+ token_button = create_toprow(is_img2img=False)
+
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
--
cgit v1.2.3
From aa7ff2a1972f3865883e10ba28c5414cdebe8e3b Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Wed, 19 Oct 2022 21:46:13 -0700
Subject: Fixed non-square highres fix generation
---
modules/processing.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 684e5833..3caac25e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -541,10 +541,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def create_dummy_mask(self, x):
+ def create_dummy_mask(self, x, first_phase: bool = False):
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+ height = self.firstphase_height if first_phase else self.height
+ width = self.firstphase_width if first_phase else self.width
+
# The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device)
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
@@ -567,7 +570,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, first_phase=True))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
--
cgit v1.2.3
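What went wrong above: create_dummy_mask always built its all-zeros "masked image" conditioning at the final target size, but the first pass of highres fix samples at firstphase_width x firstphase_height, so non-square targets produced mismatched conditioning tensors. A sketch of the corrected shape selection, with illustrative names and values:

    import torch

    def dummy_image_conditioning(x, height, width,
                                 firstphase_height, firstphase_width,
                                 first_phase=False):
        h = firstphase_height if first_phase else height
        w = firstphase_width if first_phase else width
        # all-zero RGB image: the entire image counts as masked
        return torch.zeros(x.shape[0], 3, h, w, device=x.device)

    x = torch.randn(2, 4, 64, 64)   # stand-in latent batch
    cond = dummy_image_conditioning(x, 768, 512, 512, 512, first_phase=True)
    print(cond.shape)               # torch.Size([2, 3, 512, 512])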
From 930b4c64f7dbce6918894d53538003e5959fd022 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 20 Oct 2022 08:18:02 +0300
Subject: allow float sizes for hypernet's layer_structure
---
modules/hypernetworks/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 08f75f15..e0741d08 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -15,7 +15,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
- layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure)))
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
name=name,
--
cgit v1.2.3
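The old parsing above stripped every non-digit with re.sub and then mapped int over the remaining characters, so multi-digit sizes were mangled and fractional sizes were impossible. A sketch of the new comma-split behavior:

    def parse_layer_structure(text):
        # new behavior: split on commas, allow floats like 1.5
        return [float(x.strip()) for x in text.split(",")]

    print(parse_layer_structure("1, 1.5, 1"))  # [1.0, 1.5, 1.0]
    print(parse_layer_structure("1, 12, 1"))   # [1.0, 12.0, 1.0]
    # old code: re.sub(r'\D', '', "1, 12, 1") == "1121" -> (1, 1, 2, 1)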
From 158d678f596d7fc304a6ce2f0dc31f8abfe62250 Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Thu, 20 Oct 2022 01:08:24 -0500
Subject: clear prompt button now works on both relevant tabs. Device detection
stuff will be added later.
---
modules/ui.py | 21 ++++++++++++++++++---
1 file changed, 18 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index cb9a6c6e..bde546cc 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -424,6 +424,16 @@ def create_seed_inputs():
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
+# setup button for clearing prompt input boxes on client side of webui
+def connect_trash_prompt(dummy_component, button, is_img2img):
+
+ button.click(
+ fn=lambda: print("Clearing prompt"),
+ _js="trash_prompt",
+ inputs=[],
+ outputs=[],
+ )
+
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
@@ -540,7 +550,7 @@ def create_toprow(is_img2img):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
prompt_style2.save_to_config = True
- return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, trash_prompt
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@@ -619,10 +629,11 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\
txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\
- token_button = create_toprow(is_img2img=False)
+ token_button, trash_prompt_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
+ connect_trash_prompt(dummy_component, trash_prompt_button, False)
with gr.Row(elem_id='txt2img_progress_row'):
with gr.Column(scale=1):
@@ -807,7 +818,11 @@ def create_ui(wrap_gradio_gpu_call):
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit,\
+ img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\
+ token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True)
+
+ connect_trash_prompt(dummy_component,trash_prompt_button, True)
with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
--
cgit v1.2.3
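The clear-prompt wiring above is purely client-side: when _js is given, Gradio 3.x runs the named JavaScript function in the browser, and the Python fn is little more than a placeholder. A minimal sketch, assuming a trash_prompt function exists in the page's script files (as the webui's javascript/ directory would provide):

    import gradio as gr

    def connect_trash_prompt(button):
        button.click(
            fn=lambda: print("Clearing prompt"),  # server-side placeholder
            _js="trash_prompt",                   # the actual work happens in the browser
            inputs=[],
            outputs=[],
        )

    with gr.Blocks() as demo:
        btn = gr.Button(value='\U0001F5D1', elem_id="trash_prompt")
        connect_trash_prompt(btn)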
From 0ddaf8d2028a7251e8c4ad93551a43b5d4700841 Mon Sep 17 00:00:00 2001
From: captin411
Date: Thu, 20 Oct 2022 00:34:55 -0700
Subject: improve face detection a lot
---
modules/textual_inversion/autocrop.py | 99 ++++++++++++++++++++++-------------
1 file changed, 62 insertions(+), 37 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index f858a958..5a551c25 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -8,12 +8,18 @@ GREEN = "#0F0"
BLUE = "#00F"
RED = "#F00"
+
def crop_image(im, settings):
""" Intelligently crop an image to the subject matter """
if im.height > im.width:
im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width))
- else:
+ elif im.width > im.height:
im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height))
+ else:
+ im = im.resize((settings.crop_width, settings.crop_height))
+
+ if im.height == im.width:
+ return im
focus = focal_point(im, settings)
@@ -78,13 +84,18 @@ def focal_point(im, settings):
[ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ]
)
- if settings.annotate_image:
- d = ImageDraw.Draw(im)
-
- average_point = poi_average(pois, settings, im=im)
+ average_point = poi_average(pois, settings)
if settings.annotate_image:
- d.ellipse([average_point.x - 25, average_point.y - 25, average_point.x + 25, average_point.y + 25], outline=GREEN)
+ d = ImageDraw.Draw(im)
+ for f in face_points:
+ d.rectangle(f.bounding(f.size), outline=RED)
+ for f in entropy_points:
+ d.rectangle(f.bounding(30), outline=BLUE)
+ for poi in pois:
+ w = max(4, 4 * 0.5 * sqrt(poi.weight))
+ d.ellipse(poi.bounding(w), fill=BLUE)
+ d.ellipse(average_point.bounding(25), outline=GREEN)
return average_point
@@ -92,22 +103,32 @@ def focal_point(im, settings):
def image_face_points(im, settings):
np_im = np.array(im)
gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
- classifier = cv2.CascadeClassifier(f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml')
-
- minsize = int(min(im.width, im.height) * 0.15) # at least N percent of the smallest side
- faces = classifier.detectMultiScale(gray, scaleFactor=1.05,
- minNeighbors=5, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
- if len(faces) == 0:
- return []
-
- rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
- if settings.annotate_image:
- for f in rects:
- d = ImageDraw.Draw(im)
- d.rectangle(f, outline=RED)
-
- return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2) for r in rects]
+ tries = [
+ [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
+ ]
+
+ for t in tries:
+ # print(t[0])
+ classifier = cv2.CascadeClassifier(t[0])
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
+ try:
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
+ minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+ except:
+ continue
+
+ if len(faces) > 0:
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects]
+ return []
def image_corner_points(im, settings):
@@ -132,8 +153,8 @@ def image_corner_points(im, settings):
focal_points = []
for point in points:
- x, y = point.ravel()
- focal_points.append(PointOfInterest(x, y))
+ x, y = point.ravel()
+ focal_points.append(PointOfInterest(x, y, size=4))
return focal_points
@@ -167,31 +188,26 @@ def image_entropy_points(im, settings):
x_mid = int(crop_best[0] + settings.crop_width/2)
y_mid = int(crop_best[1] + settings.crop_height/2)
- return [PointOfInterest(x_mid, y_mid)]
+ return [PointOfInterest(x_mid, y_mid, size=25)]
def image_entropy(im):
# greyscale image entropy
- band = np.asarray(im.convert("1"))
+ # band = np.asarray(im.convert("L"))
+ band = np.asarray(im.convert("1"), dtype=np.uint8)
hist, _ = np.histogram(band, bins=range(0, 256))
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
-def poi_average(pois, settings, im=None):
+def poi_average(pois, settings):
weight = 0.0
x = 0.0
y = 0.0
- for pois in pois:
- if settings.annotate_image and im is not None:
- w = 4 * 0.5 * sqrt(pois.weight)
- d = ImageDraw.Draw(im)
- d.ellipse([
- pois.x - w, pois.y - w,
- pois.x + w, pois.y + w ], fill=BLUE)
- weight += pois.weight
- x += pois.x * pois.weight
- y += pois.y * pois.weight
+ for poi in pois:
+ weight += poi.weight
+ x += poi.x * poi.weight
+ y += poi.y * poi.weight
avg_x = round(x / weight)
avg_y = round(y / weight)
@@ -199,10 +215,19 @@ def poi_average(pois, settings, im=None):
class PointOfInterest:
- def __init__(self, x, y, weight=1.0):
+ def __init__(self, x, y, weight=1.0, size=10):
self.x = x
self.y = y
self.weight = weight
+ self.size = size
+
+ def bounding(self, size):
+ return [
+ self.x - size//2,
+ self.y - size//2,
+ self.x + size//2,
+ self.y + size//2
+ ]
class Settings:
--
cgit v1.2.3
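The rewritten detector above walks a list of Haar cascades from most specific (eyes) to most generic (upper body), each with its own minimum-size fraction, and returns on the first cascade that fires. A self-contained sketch of that strategy (cascade list shortened; requires opencv-python):

    import cv2
    import numpy as np
    from PIL import Image

    def face_centers(im):
        gray = cv2.cvtColor(np.array(im), cv2.COLOR_BGR2GRAY)
        tries = [
            (f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01),
            (f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05),
            (f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05),
        ]
        for xml, fraction in tries:
            classifier = cv2.CascadeClassifier(xml)
            minsize = int(min(im.width, im.height) * fraction)
            faces = classifier.detectMultiScale(
                gray, scaleFactor=1.1, minNeighbors=7,
                minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
            if len(faces) > 0:
                # each hit is (x, y, w, h); reduce to center points
                return [(x + w // 2, y + h // 2) for x, y, w, h in faces]
        return []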
From f8733ad08be08bafb40f4299785590e11f049e96 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Thu, 20 Oct 2022 11:07:37 +0000
Subject: add linear as an act func (option for doing nothing)
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 913b23b4..716f14b8 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1224,7 +1224,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["relu", "leakyrelu"])
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row():
with gr.Column(scale=3):
--
cgit v1.2.3
From 9681419e422515e42444e0174355b760645a846f Mon Sep 17 00:00:00 2001
From: Milly
Date: Thu, 20 Oct 2022 16:53:46 +0900
Subject: train: fixed preprocess image ratio
---
modules/textual_inversion/preprocess.py | 54 +++++++++++++++++++++------------
1 file changed, 35 insertions(+), 19 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 886cf0c3..2743bdeb 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,5 +1,6 @@
import os
from PIL import Image, ImageOps
+import math
import platform
import sys
import tqdm
@@ -38,6 +39,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
+ split_threshold = 0.5
+ overlap_ratio = 0.2
assert src != dst, 'same directory specified as source and destination'
@@ -78,6 +81,29 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index)
+ def split_pic(image, inverse_xy):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
+ else:
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
+
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
@@ -89,26 +115,16 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
if shared.state.interrupted:
break
- ratio = img.height / img.width
- is_tall = ratio > 1.35
- is_wide = ratio < 1 / 1.35
-
- if process_split and is_tall:
- img = img.resize((width, height * img.height // img.width))
-
- top = img.crop((0, 0, width, height))
- save_pic(top, index)
-
- bot = img.crop((0, img.height - height, width, img.height))
- save_pic(bot, index)
- elif process_split and is_wide:
- img = img.resize((width * img.width // img.height, height))
-
- left = img.crop((0, 0, width, height))
- save_pic(left, index)
+ if img.height > img.width:
+ ratio = (img.width * height) / (img.height * width)
+ inverse_xy = False
+ else:
+ ratio = (img.height * width) / (img.width * height)
+ inverse_xy = True
- right = img.crop((img.width - width, 0, img.width, height))
- save_pic(right, index)
+ if process_split and ratio < 1.0 and ratio <= split_threshold:
+ for splitted in split_pic(img, inverse_xy):
+ save_pic(splitted, index)
else:
img = images.resize_image(1, img, width, height)
save_pic(img, index)
--
cgit v1.2.3
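The split arithmetic above: after scaling so the short side matches the target, the long side h is carved into split_count tiles of height to_h whose pairwise overlap is at least overlap_ratio of a tile. A worked example with the defaults from the patch:

    import math

    h, to_h, overlap_ratio = 1280, 512, 0.2   # scaled long side, tile size, min overlap
    split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
    y_step = (h - to_h) / (split_count - 1)

    print(split_count, y_step)                            # 3 384.0
    print([int(y_step * i) for i in range(split_count)])  # tile tops: [0, 384, 768]
    # consecutive tiles share 512 - 384 = 128 px, i.e. 25% >= the requested 20%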
From 85dd62c4c7635b8e21a75f140d093036069e97a1 Mon Sep 17 00:00:00 2001
From: Milly
Date: Thu, 20 Oct 2022 22:56:45 +0900
Subject: train: ui: added `Split image threshold` and `Split image overlap
ratio` to preprocess
---
modules/textual_inversion/preprocess.py | 10 +++++-----
modules/ui.py | 16 ++++++++++++++--
2 files changed, 19 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 2743bdeb..c8df8aa0 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -12,7 +12,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
try:
if process_caption:
shared.interrogator.load()
@@ -22,7 +22,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
- preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+ preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio)
finally:
@@ -34,13 +34,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
width = process_width
height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
- split_threshold = 0.5
- overlap_ratio = 0.2
+ split_threshold = max(0.0, min(1.0, split_threshold))
+ overlap_ratio = max(0.0, min(0.9, overlap_ratio))
assert src != dst, 'same directory specified as source and destination'
diff --git a/modules/ui.py b/modules/ui.py
index a2dbd41e..bc7f3330 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1240,10 +1240,14 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
- process_split = gr.Checkbox(label='Split oversized images into two')
+ process_split = gr.Checkbox(label='Split oversized images')
process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
+ with gr.Row(visible=False) as process_split_extra_row:
+ process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
+ process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
+
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1251,6 +1255,12 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
run_preprocess = gr.Button(value="Preprocess", variant='primary')
+ process_split.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_split],
+ outputs=[process_split_extra_row],
+ )
+
with gr.Tab(label="Train"):
gr.HTML(value="Train an embedding; must specify a directory with a set of 1:1 ratio images
")
with gr.Row():
@@ -1327,7 +1337,9 @@ def create_ui(wrap_gradio_gpu_call):
process_flip,
process_split,
process_caption,
- process_caption_deepbooru
+ process_caption_deepbooru,
+ process_split_threshold,
+ process_overlap_ratio,
],
outputs=[
ti_output,
--
cgit v1.2.3
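The visibility toggle above follows a standard pattern: the checkbox's change event returns a component update that shows or hides the whole slider row. gr_show in the webui wraps this; plain gr.update does the same in a standalone Gradio 3.x sketch:

    import gradio as gr

    with gr.Blocks() as demo:
        process_split = gr.Checkbox(label='Split oversized images')
        with gr.Row(visible=False) as extra_row:
            gr.Slider(label='Split image threshold', value=0.5,
                      minimum=0.0, maximum=1.0, step=0.05)
            gr.Slider(label='Split image overlap ratio', value=0.2,
                      minimum=0.0, maximum=0.9, step=0.05)

        process_split.change(
            fn=lambda show: gr.update(visible=show),  # show the row iff the box is ticked
            inputs=[process_split],
            outputs=[extra_row],
        )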
From d8acd34f66ab35a91f10d66330bcc95a83bfcac6 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Thu, 20 Oct 2022 23:43:03 +0900
Subject: generalized some functions and option for ignoring first layer
---
modules/hypernetworks/hypernetwork.py | 23 +++++++++++++++--------
1 file changed, 15 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7d617680..3a44b377 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -21,21 +21,27 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
-
+ activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU,
+ "swish": torch.nn.Hardswish}
+
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
-
+
linears = []
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
- if activation_func == "relu":
- linears.append(torch.nn.ReLU())
- if activation_func == "leakyrelu":
- linears.append(torch.nn.LeakyReLU())
+ # if skip_first_layer because first parameters potentially contain negative values
+ if i < 1: continue
+ if activation_func in HypernetworkModule.activation_dict:
+ linears.append(HypernetworkModule.activation_dict[activation_func]())
+ else:
+ print("Invalid key {} encountered as activation function!".format(activation_func))
+ # if use_dropout:
+ linears.append(torch.nn.Dropout(p=0.3))
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
- if not "ReLU" in layer.__str__():
+ if isinstance(layer, torch.nn.Linear):
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
@@ -298,7 +304,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+ # if optimizer == "Adam": or else Adam / AdamW / etc...
+ optimizer = torch.optim.Adam(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
--
cgit v1.2.3
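Replacing the if/elif chain with a class-level dictionary, as the commit above does, makes new activations one-line additions and turns unknown names into a single reportable case. A sketch of the table-driven lookup (the Identity fallback is an assumption for self-containedness; the patch just prints and appends nothing):

    import torch

    activation_dict = {
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,  # the patch maps "swish" to Hardswish
    }

    def make_activation(name):
        if name in activation_dict:
            return activation_dict[name]()
        print("Invalid key {} encountered as activation function!".format(name))
        return torch.nn.Identity()    # assumed no-op fallback, not in the patch

    print(make_activation("elu"))     # ELU(alpha=1.0)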
From a71e0212363979c7cbbb797c9fbd5f8cd03b29d3 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Thu, 20 Oct 2022 23:48:52 +0900
Subject: only linear
---
modules/hypernetworks/hypernetwork.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3a44b377..905cbeef 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -35,13 +35,13 @@ class HypernetworkModule(torch.nn.Module):
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# if skip_first_layer because first parameters potentially contain negative values
- if i < 1: continue
+ # if i < 1: continue
if activation_func in HypernetworkModule.activation_dict:
linears.append(HypernetworkModule.activation_dict[activation_func]())
else:
print("Invalid key {} encountered as activation function!".format(activation_func))
# if use_dropout:
- linears.append(torch.nn.Dropout(p=0.3))
+ # linears.append(torch.nn.Dropout(p=0.3))
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
- if not "ReLU" in layer.__str__():
+ if isinstance(layer, torch.nn.Linear):
layer_structure += [layer.weight, layer.bias]
return layer_structure
@@ -304,8 +304,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- # if optimizer == "Adam": or else Adam / AdamW / etc...
- optimizer = torch.optim.Adam(weights, lr=scheduler.learn_rate)
+ # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
--
cgit v1.2.3
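The isinstance check above matters because repr-string matching ("ReLU" in str(layer)) breaks once ELU, Hardswish, or LayerNorm appear in the stack; type testing selects exactly the Linear layers whose parameters should be initialized and trained. A standalone sketch of the filter feeding the AdamW optimizer:

    import torch

    net = torch.nn.Sequential(
        torch.nn.Linear(4, 8), torch.nn.LeakyReLU(), torch.nn.Linear(8, 4))

    weights = []
    for layer in net:
        if isinstance(layer, torch.nn.Linear):
            weights += [layer.weight, layer.bias]

    optimizer = torch.optim.AdamW(weights, lr=1e-4)
    print(len(weights))  # 4 tensors: two weight matrices, two bias vectors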
From d07cb46f34b3d9fe7a78b102f899ebef352ea56b Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Thu, 20 Oct 2022 23:58:52 +0800
Subject: inspiration pull request
---
modules/inspiration.py | 122 +++++++++++++++++++++++++++++++++++++++++++++++++
modules/shared.py | 1 +
modules/ui.py | 13 ++++--
3 files changed, 131 insertions(+), 5 deletions(-)
create mode 100644 modules/inspiration.py
(limited to 'modules')
diff --git a/modules/inspiration.py b/modules/inspiration.py
new file mode 100644
index 00000000..456bfcb5
--- /dev/null
+++ b/modules/inspiration.py
@@ -0,0 +1,122 @@
+import os
+import random
+import gradio
+inspiration_path = "inspiration"
+inspiration_system_path = os.path.join(inspiration_path, "system")
+def read_name_list(file):
+ if not os.path.exists(file):
+ return []
+ f = open(file, "r")
+ ret = []
+ line = f.readline()
+ while len(line) > 0:
+ line = line.rstrip("\n")
+ ret.append(line)
+ line = f.readline()
+ print(ret)
+ return ret
+
+def save_name_list(file, name):
+ print(file)
+ f = open(file, "a")
+ f.write(name + "\n")
+
+def get_inspiration_images(source, types):
+ path = os.path.join(inspiration_path , types)
+ if source == "Favorites":
+ names = read_name_list(os.path.join(inspiration_system_path, types + "_faverites.txt"))
+ names = random.sample(names, 25)
+ elif source == "Abandoned":
+ names = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt"))
+ names = random.sample(names, 25)
+ elif source == "Exclude abandoned":
+ abondened = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt"))
+ all_names = os.listdir(path)
+ names = []
+ while len(names) < 25:
+ name = random.choice(all_names)
+ if name not in abondened:
+ names.append(name)
+ else:
+ names = random.sample(os.listdir(path), 25)
+ names = random.sample(names, 25)
+ image_list = []
+ for a in names:
+ image_path = os.path.join(path, a)
+ images = os.listdir(image_path)
+ image_list.append(os.path.join(image_path, random.choice(images)))
+ return image_list, names
+
+def select_click(index, types, name_list):
+ name = name_list[int(index)]
+ path = os.path.join(inspiration_path, types, name)
+ images = os.listdir(path)
+ return name, [os.path.join(path, x) for x in images]
+
+def give_up_click(name, types):
+ file = os.path.join(inspiration_system_path, types + "_abandoned.txt")
+ name_list = read_name_list(file)
+ if name not in name_list:
+ save_name_list(file, name)
+
+def collect_click(name, types):
+ file = os.path.join(inspiration_system_path, types + "_faverites.txt")
+ print(file)
+ name_list = read_name_list(file)
+ print(name_list)
+ if name not in name_list:
+ save_name_list(file, name)
+
+def moveout_click(name, types):
+ file = os.path.join(inspiration_system_path, types + "_faverites.txt")
+ name_list = read_name_list(file)
+ if name in name_list:
+ name_list.remove(name)
+ with open(file, "w") as f:
+ for a in name_list:
+ f.write(a + "\n")
+
+def source_change(source):
+ if source == "Abandoned" or source == "Favorites":
+ return gradio.Button.update(visible=True, value=f"Move out {source}")
+ else:
+ return gradio.Button.update(visible=False)
+
+def ui(gr, opts):
+ with gr.Blocks(analytics_enabled=False) as inspiration:
+ flag = os.path.exists(inspiration_path)
+ if flag:
+ types = os.listdir(inspiration_path)
+ types = [x for x in types if x != "system"]
+ flag = len(types) > 0
+ if not flag:
+ os.mkdir(inspiration_path)
+ gr.HTML("""
+ "
+ """)
+ return inspiration
+ if not os.path.exists(inspiration_system_path):
+ os.mkdir(inspiration_system_path)
+ gallery, names = get_inspiration_images("Exclude abandoned", types[0])
+ with gr.Row():
+ with gr.Column(scale=2):
+ inspiration_gallery = gr.Gallery(gallery, show_label=False, elem_id="inspiration_gallery").style(grid=5, height='auto')
+ with gr.Column(scale=1):
+ types = gr.Dropdown(choices=types, value=types[0], label="Type", visible=len(types) > 1)
+ with gr.Row():
+ source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source")
+ get_inspiration = gr.Button("Get inspiration")
+ name = gr.Textbox(show_label=False, interactive=False)
+ with gr.Row():
+ send_to_txt2img = gr.Button('to txt2img')
+ send_to_img2img = gr.Button('to img2img')
+ style_gallery = gr.Gallery(show_label=False, elem_id="inspiration_style_gallery").style(grid=2, height='auto')
+
+ collect = gr.Button('Collect')
+ give_up = gr.Button("Don't show any more")
+ moveout = gr.Button("Move out", visible=False)
+ with gr.Row():
+ select_button = gr.Button('set button', elem_id="inspiration_select_button")
+ name_list = gr.State(names)
+ source.change(source_change, inputs=[source], outputs=[moveout])
+ get_inspiration.click(get_inspiration_images, inputs=[source, types], outputs=[inspiration_gallery, name_list])
+ select_button.click(select_click, _js="inspiration_selected", inputs=[name, types, name_list], outputs=[name, style_gallery])
+ give_up.click(give_up_click, inputs=[name, types], outputs=None)
+ collect.click(collect_click, inputs=[name, types], outputs=None)
+ return inspiration
diff --git a/modules/shared.py b/modules/shared.py
index faede821..ae033710 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -78,6 +78,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
+parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
cmd_opts = parser.parse_args()
restricted_opts = [
diff --git a/modules/ui.py b/modules/ui.py
index a2dbd41e..6a0a3c3b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -41,7 +41,8 @@ from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui
-import modules.images_history as img_his
+import modules.images_history as images_history
+import modules.inspiration as inspiration
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -1082,9 +1083,9 @@ def create_ui(wrap_gradio_gpu_call):
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
-
+
with gr.Group():
- extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
+ extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers] , value=shared.sd_upscalers[0].name, type="index")
with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
@@ -1178,7 +1179,8 @@ def create_ui(wrap_gradio_gpu_call):
"i2i":img2img_paste_fields
}
- images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
+ browser_interface = images_history.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
+ inspiration_interface = inspiration.ui(gr, opts)
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
@@ -1595,7 +1597,8 @@ Requested path was: {f}
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
- (images_history, "History", "images_history"),
+ (browser_interface, "History", "images_history"),
+ (inspiration_interface, "Inspiration", "inspiration"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
(settings_interface, "Settings", "settings"),
--
cgit v1.2.3
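The inspiration tab above persists its favorites/abandoned sets as append-only text files, one name per line, re-read on every query. A compact equivalent of read_name_list/save_name_list that also closes its file handles (path is illustrative):

    import os

    def read_name_list(file):
        if not os.path.exists(file):
            return []
        with open(file, "r") as f:           # context manager closes the handle
            return [line.rstrip("\n") for line in f]

    def save_name_list(file, name):
        with open(file, "a") as f:
            f.write(name + "\n")

    save_name_list("/tmp/artists_favorites.txt", "some-artist")
    print(read_name_list("/tmp/artists_favorites.txt"))  # ['some-artist']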
From 108be15500aac590b4e00420635d7b61fccfa530 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Fri, 21 Oct 2022 01:00:41 +0900
Subject: fix bugs and optimizations
---
modules/hypernetworks/hypernetwork.py | 105 +++++++++++++++++++---------------
1 file changed, 59 insertions(+), 46 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 905cbeef..893ba110 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -36,14 +36,14 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# if skip_first_layer because first parameters potentially contain negative values
# if i < 1: continue
+ if add_layer_norm:
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
if activation_func in HypernetworkModule.activation_dict:
linears.append(HypernetworkModule.activation_dict[activation_func]())
else:
print("Invalid key {} encountered as activation function!".format(activation_func))
# if use_dropout:
# linears.append(torch.nn.Dropout(p=0.3))
- if add_layer_norm:
- linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
self.linear = torch.nn.Sequential(*linears)
@@ -115,11 +115,24 @@ class Hypernetwork:
for k, layers in self.layers.items():
for layer in layers:
- layer.train()
res += layer.trainables()
return res
+ def eval(self):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.eval()
+ for items in self.weights():
+ items.requires_grad = False
+
+ def train(self):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.train()
+ for items in self.weights():
+ items.requires_grad = True
+
def save(self, filename):
state_dict = {}
@@ -290,10 +303,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork = shared.loaded_hypernetwork
- weights = hypernetwork.weights()
- for weight in weights:
- weight.requires_grad = True
-
losses = torch.zeros((32,))
last_saved_file = "
"
@@ -304,10 +313,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
- optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+ optimizer = torch.optim.AdamW(hypernetwork.weights(), lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
+ hypernetwork.train()
for i, entries in pbar:
hypernetwork.step = i + ititial_step
@@ -328,8 +337,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item()
- optimizer.zero_grad()
+ optimizer.zero_grad(set_to_none=True)
loss.backward()
+ del loss
optimizer.step()
mean_loss = losses.mean()
if torch.isnan(mean_loss):
@@ -346,44 +356,47 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
+ torch.cuda.empty_cache()
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+ with torch.no_grad():
+ hypernetwork.eval()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
- optimizer.zero_grad()
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
-
- preview_text = p.prompt
-
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images)>0 else None
-
- if unload:
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
-
- if image is not None:
- shared.state.current_image = image
- image.save(last_saved_image)
- last_saved_image += f", prompt: {preview_text}"
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images)>0 else None
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ if image is not None:
+ shared.state.current_image = image
+ image.save(last_saved_image)
+ last_saved_image += f", prompt: {preview_text}"
+
+ hypernetwork.train()
shared.state.job_no = hypernetwork.step
--
cgit v1.2.3
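The commit above couples module-mode switching with `requires_grad` toggling, so that generating previews under `torch.no_grad()` cannot leave the weights in the wrong state afterwards. A minimal sketch of the same pattern, using a toy stand-in rather than the repository's `Hypernetwork` class:
```python
# Sketch only: toy stand-in for the eval()/train() pair added above.
import torch

class ToyHypernetwork:
    def __init__(self):
        self.layers = {320: (torch.nn.Linear(320, 320), torch.nn.Linear(320, 320))}

    def weights(self):
        return [p for pair in self.layers.values() for m in pair for p in m.parameters()]

    def eval(self):
        for pair in self.layers.values():
            for m in pair:
                m.eval()
        for p in self.weights():
            p.requires_grad_(False)  # freeze for preview generation

    def train(self):
        for pair in self.layers.values():
            for m in pair:
                m.train()
        for p in self.weights():
            p.requires_grad_(True)   # unfreeze before resuming training

net = ToyHypernetwork()
net.eval()
assert not any(p.requires_grad for p in net.weights())
net.train()
assert all(p.requires_grad for p in net.weights())
```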
From f89829ec3a0baceb445451ad98d4fb4323e922aa Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Fri, 21 Oct 2022 01:37:11 +0900
Subject: Revert "fix bugs and optimizations"
This reverts commit 108be15500aac590b4e00420635d7b61fccfa530.
---
modules/hypernetworks/hypernetwork.py | 105 +++++++++++++++-------------------
1 file changed, 46 insertions(+), 59 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 893ba110..905cbeef 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -36,14 +36,14 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# if skip_first_layer because first parameters potentially contain negative values
# if i < 1: continue
- if add_layer_norm:
- linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
if activation_func in HypernetworkModule.activation_dict:
linears.append(HypernetworkModule.activation_dict[activation_func]())
else:
print("Invalid key {} encountered as activation function!".format(activation_func))
# if use_dropout:
# linears.append(torch.nn.Dropout(p=0.3))
+ if add_layer_norm:
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
self.linear = torch.nn.Sequential(*linears)
@@ -115,24 +115,11 @@ class Hypernetwork:
for k, layers in self.layers.items():
for layer in layers:
+ layer.train()
res += layer.trainables()
return res
- def eval(self):
- for k, layers in self.layers.items():
- for layer in layers:
- layer.eval()
- for items in self.weights():
- items.requires_grad = False
-
- def train(self):
- for k, layers in self.layers.items():
- for layer in layers:
- layer.train()
- for items in self.weights():
- items.requires_grad = True
-
def save(self, filename):
state_dict = {}
@@ -303,6 +290,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork = shared.loaded_hypernetwork
+ weights = hypernetwork.weights()
+ for weight in weights:
+ weight.requires_grad = True
+
losses = torch.zeros((32,))
last_saved_file = ""
@@ -313,10 +304,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- optimizer = torch.optim.AdamW(hypernetwork.weights(), lr=scheduler.learn_rate)
+ # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- hypernetwork.train()
for i, entries in pbar:
hypernetwork.step = i + ititial_step
@@ -337,9 +328,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item()
- optimizer.zero_grad(set_to_none=True)
+ optimizer.zero_grad()
loss.backward()
- del loss
optimizer.step()
mean_loss = losses.mean()
if torch.isnan(mean_loss):
@@ -356,47 +346,44 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
- torch.cuda.empty_cache()
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
- with torch.no_grad():
- hypernetwork.eval()
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
-
- preview_text = p.prompt
-
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images)>0 else None
-
- if unload:
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
-
- if image is not None:
- shared.state.current_image = image
- image.save(last_saved_image)
- last_saved_image += f", prompt: {preview_text}"
-
- hypernetwork.train()
+ optimizer.zero_grad()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images)>0 else None
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ if image is not None:
+ shared.state.current_image = image
+ image.save(last_saved_image)
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
--
cgit v1.2.3
From 92a17a7a4a13fceb3c3e25a2e854b2a7dd6eb5df Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Thu, 20 Oct 2022 09:45:03 -0700
Subject: Made dummy latents smaller. Minor code cleanups
---
modules/processing.py | 7 ++++---
modules/sd_samplers.py | 6 ++++--
2 files changed, 8 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 3caac25e..539cde38 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -557,7 +557,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
- image_conditioning = torch.zeros(x.shape[0], 5, x.shape[-2], x.shape[-1], dtype=x.dtype, device=x.device)
+ # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
+ image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
return image_conditioning
@@ -759,8 +760,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
else:
self.image_conditioning = torch.zeros(
- self.init_latent.shape[0], 5, self.init_latent.shape[-2], self.init_latent.shape[-1],
- dtype=self.init_latent.dtype,
+ self.init_latent.shape[0], 5, 1, 1,
+ dtype=self.init_latent.dtype,
device=self.init_latent.device
)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index c21be26e..cc682593 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -138,7 +138,7 @@ class VanillaStableDiffusionSampler:
if self.stop_at is not None and self.step > self.stop_at:
raise InterruptedException
- # Have to unwrap the inpainting conditioning here to perform pre-preocessing
+ # Have to unwrap the inpainting conditioning here to perform pre-processing
image_conditioning = None
if isinstance(cond, dict):
image_conditioning = cond["c_concat"][0]
@@ -146,7 +146,7 @@ class VanillaStableDiffusionSampler:
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
@@ -165,6 +165,8 @@ class VanillaStableDiffusionSampler:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
+ # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
+ # Note that they need to be lists because it just concatenates them later.
if image_conditioning is not None:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
--
cgit v1.2.3
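The dummy conditioning only needs its batch dimension to be right, so shrinking the spatial size to 1x1 cuts it from thousands of elements to a handful. A small demonstration with a placeholder latent batch:
```python
# Sketch: memory difference between full-size and 1x1 dummy conditioning.
import torch

x = torch.zeros(4, 4, 64, 64)  # placeholder for a batch of latents

full = torch.zeros(x.shape[0], 5, x.shape[-2], x.shape[-1], dtype=x.dtype)
tiny = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype)

print(full.numel(), "vs", tiny.numel())  # 81920 vs 20 elements
```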
From d1cb08bfb221cd1b0cfc6078162b4e206ea80a5c Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Thu, 20 Oct 2022 22:49:06 +0300
Subject: fix skip and interrupt for the "Highres. fix" option
---
modules/processing.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index bcb0c32c..6324ca91 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -587,9 +587,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
-
- return samples
+ return self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) or samples
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
--
cgit v1.2.3
From 708c3a7bd8ce68cbe1aa7c268e5a4b1980affc9f Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Thu, 20 Oct 2022 13:28:43 -0700
Subject: Added PLMS hijack and made sure to always replace methods
---
modules/sd_hijack_inpainting.py | 163 ++++++++++++++++++++++++++++++++++++++--
modules/sd_models.py | 3 +-
2 files changed, 157 insertions(+), 9 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index d4d28d2e..43938071 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -1,16 +1,14 @@
import torch
-import numpy as np
-from tqdm import tqdm
-from einops import rearrange, repeat
+from einops import repeat
from omegaconf import ListConfig
-from types import MethodType
-
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
# =================================================================================================
@@ -19,7 +17,7 @@ from ldm.models.diffusion.ddim import DDIMSampler, noise_like
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
# =================================================================================================
@torch.no_grad()
-def sample(self,
+def sample_ddim(self,
S,
batch_size,
shape,
@@ -132,6 +130,153 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F
return x_prev, pred_x0
+# =================================================================================================
+# Monkey patch PLMSSampler methods.
+# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
+# Adapted from:
+# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
+# =================================================================================================
+@torch.no_grad()
+def sample_plms(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list):
+ ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+
+@torch.no_grad()
+def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
+ for i in range(len(c[k]))
+ ]
+ else:
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
+
# =================================================================================================
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
# Adapted from:
@@ -175,5 +320,9 @@ def should_hijack_inpainting(checkpoint_info):
def do_inpainting_hijack():
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
+
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
- ldm.models.diffusion.ddim.DDIMSampler.sample = sample
\ No newline at end of file
+ ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
+
+ ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
+ ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
\ No newline at end of file
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 47836d25..7072db08 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -214,8 +214,6 @@ def load_model():
sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
- do_inpainting_hijack()
-
# Hardcoded config for now...
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
sd_config.model.params.use_ema = False
@@ -225,6 +223,7 @@ def load_model():
# Create a "fake" config with a different name so that we know to unload it when switching models.
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+ do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
--
cgit v1.2.3
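The patched `p_sample_plms` blends the current noise prediction with up to three stored ones using Adams-Bashforth coefficients, so the order of the multistep formula grows as history accumulates. A sketch of just that selection logic, with plain floats standing in for the prediction tensors:
```python
# Hedged sketch of the multistep blending in p_sample_plms above.
def plms_blend(e_t, old_eps):
    if len(old_eps) == 0:
        return None  # first step needs the Euler-style probe (see patch)
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

# constant predictions blend back to the same constant at every order
print(plms_blend(1.0, [1.0]), plms_blend(1.0, [1.0, 1.0, 1.0]))  # 1.0 1.0
```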
From d23a46ceaa76af2847f11172f32c92665c268b1b Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Thu, 20 Oct 2022 23:49:14 +0300
Subject: Different approach to skip/interrupt with highres fix
---
modules/processing.py | 4 +++-
modules/sd_samplers.py | 4 ++++
2 files changed, 7 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 6324ca91..bcb0c32c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -587,7 +587,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- return self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) or samples
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+
+ return samples
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index b58e810b..7ff77c01 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -196,6 +196,7 @@ class VanillaStableDiffusionSampler:
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
+ self.last_latent = x
self.step = 0
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
@@ -206,6 +207,7 @@ class VanillaStableDiffusionSampler:
self.initialize(p)
self.init_latent = None
+ self.last_latent = x
self.step = 0
steps = steps or p.steps
@@ -388,6 +390,7 @@ class KDiffusionSampler:
extra_params_kwargs['sigmas'] = sigma_sched
self.model_wrap_cfg.init_latent = x
+ self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
@@ -414,6 +417,7 @@ class KDiffusionSampler:
else:
extra_params_kwargs['sigmas'] = sigmas
+ self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
--
cgit v1.2.3
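Storing `self.last_latent` before each sampling launch gives the interrupt path something to return besides `None`. A toy sketch of the idea, with a stand-in exception and a fake denoising loop:
```python
# Sketch: keep the last good latent so an interrupt yields a partial result.
class InterruptedException(Exception):
    pass

class ToySampler:
    def __init__(self):
        self.last_latent = None

    def sample(self, x, interrupt_at=None):
        self.last_latent = x
        try:
            for step in range(10):
                if step == interrupt_at:
                    raise InterruptedException
                x = x * 0.9              # stand-in for one denoising step
                self.last_latent = x
        except InterruptedException:
            return self.last_latent      # partial result instead of None
        return x

print(ToySampler().sample(1.0, interrupt_at=3))  # 0.9 ** 3 = 0.729
```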
From 49533eed9e3aad19e9868ee140708baec4fd44be Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Thu, 20 Oct 2022 16:01:27 -0700
Subject: XY grid correctly re-assigns model when config changes
---
modules/sd_models.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7072db08..fea84630 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -204,9 +204,9 @@ def load_model_weights(model, checkpoint_info):
model.sd_checkpoint_info = checkpoint_info
-def load_model():
+def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
- checkpoint_info = select_checkpoint()
+ checkpoint_info = checkpoint_info or select_checkpoint()
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
@@ -249,7 +249,7 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
- shared.sd_model = load_model()
+ shared.sd_model = load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
--
cgit v1.2.3
From a3b047b7c74dc6ca07f40aee778997fc1889d72f Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Thu, 20 Oct 2022 19:28:58 -0500
Subject: add settings option to toggle button visibility
---
modules/shared.py | 1 +
modules/ui.py | 2 +-
2 files changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index faede821..7e9c2696 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -300,6 +300,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
+ "trash_prompt_visible": OptionInfo(True, "Show trash prompt button"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
diff --git a/modules/ui.py b/modules/ui.py
index bde546cc..13c0b4ca 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -509,7 +509,7 @@ def create_toprow(is_img2img):
paste = gr.Button(value=paste_symbol, elem_id="paste")
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
- trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt")
+ trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt", visible=opts.trash_prompt_visible)
token_counter = gr.HTML(value=" ", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
--
cgit v1.2.3
From 45872181902ada06267e2de601586d512cf5df1a Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 09:00:39 +0300
Subject: updated readme and some small stylistic changes to code
---
modules/processing.py | 14 ++++++--------
modules/sd_hijack_inpainting.py | 3 +++
2 files changed, 9 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 539cde38..21786968 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -540,11 +540,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
-
- def create_dummy_mask(self, x, first_phase: bool = False):
+ def create_dummy_mask(self, x, width=None, height=None):
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- height = self.firstphase_height if first_phase else self.height
- width = self.firstphase_width if first_phase else self.width
+ height = height or self.height
+ width = width or self.width
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
@@ -571,7 +570,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, first_phase=True))
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
@@ -634,6 +633,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.inpainting_mask_invert = inpainting_mask_invert
self.mask = None
self.nmask = None
+ self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
@@ -735,9 +735,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- conditioning_key = self.sampler.conditioning_key
-
- if conditioning_key in {'hybrid', 'concat'}:
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
if self.image_mask is not None:
conditioning_mask = np.array(self.image_mask.convert("L"))
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 43938071..fd92a335 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -301,6 +301,7 @@ def get_unconditional_conditioning(self, batch_size, null_label=None):
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
return c
+
class LatentInpaintDiffusion(LatentDiffusion):
def __init__(
self,
@@ -314,9 +315,11 @@ class LatentInpaintDiffusion(LatentDiffusion):
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
+
def should_hijack_inpainting(checkpoint_info):
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
+
def do_inpainting_hijack():
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
--
cgit v1.2.3
From 74088c2a06a975092806362aede22f82716cb011 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 20 Oct 2022 08:18:02 +0300
Subject: allow float sizes for hypernet's layer_structure
---
modules/hypernetworks/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 08f75f15..e0741d08 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -15,7 +15,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
- layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure)))
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
name=name,
--
cgit v1.2.3
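The change matters because the old regex stripped every non-digit and then read the remainder one character at a time, so fractional or multi-digit multipliers were silently mangled. A quick demonstration of what the comma split fixes:
```python
# Demonstration of the parsing bug fixed above.
import re

s = "1, 1.5, 1"
old = tuple(map(int, re.sub(r'\D', '', s)))      # (1, 1, 5, 1) -- mangled
new = [float(x.strip()) for x in s.split(",")]   # [1.0, 1.5, 1.0]
print(old, new)
```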
From 60872c5b404114336f9ca0c671ba88fa4a8201c9 Mon Sep 17 00:00:00 2001
From: winterspringsummer
Date: Thu, 20 Oct 2022 19:10:32 +0900
Subject: Fixed path issue during extras batch processing
---
modules/extras.py | 12 ++++++++----
1 file changed, 8 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index b853fa5b..f9796624 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -118,10 +118,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
-
- images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
- forced_filename=image_name if opts.use_original_name_batch else None)
+
+ if opts.use_original_name_batch and image_name != None:
+ basename = os.path.splitext(os.path.basename(image_name))[0]
+ else:
+ basename = ''
+
+ images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+ no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if opts.enable_pnginfo:
image.info = existing_pnginfo
--
cgit v1.2.3
From fb5a8cf0d9ed027ea3aa2e5422c946d8e6e72efe Mon Sep 17 00:00:00 2001
From: winterspringsummer
Date: Thu, 20 Oct 2022 21:31:29 +0900
Subject: Added try except to extras batch from directory
---
modules/extras.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index f9796624..0d817cf9 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -41,7 +41,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return outputs, "Please select an input directory.", ''
image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
for img in image_list:
- image = Image.open(img)
+ try:
+ image = Image.open(img)
+ except Exception:
+ continue
imageArr.append(image)
imageNameArr.append(img)
else:
@@ -122,10 +125,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
- basename = ''
+ basename = None
- images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
+ images.save_image(image, path=outpath, basename='', seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+ no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=basename)
if opts.enable_pnginfo:
image.info = existing_pnginfo
--
cgit v1.2.3
From a13c3bed3cec27afe3c015d3d62db36e25b10d1f Mon Sep 17 00:00:00 2001
From: winterspringsummer
Date: Thu, 20 Oct 2022 21:43:27 +0900
Subject: Fixed path issue during extras batch processing
---
modules/extras.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index 0d817cf9..ac85142c 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -125,10 +125,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
- basename = None
+ basename = ''
- images.save_image(image, path=outpath, basename='', seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=basename)
+ images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+ no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if opts.enable_pnginfo:
image.info = existing_pnginfo
--
cgit v1.2.3
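Across these extras-batch commits the naming logic settles on reusing the original file's stem when `use_original_name_batch` is set. A small sketch of that final rule, with a hypothetical input path:
```python
# Sketch of the basename rule the batch code converges on.
import os

def batch_basename(image_name, use_original_name_batch=True):
    if use_original_name_batch and image_name is not None:
        return os.path.splitext(os.path.basename(image_name))[0]
    return ''

print(batch_basename("/inputs/photo.01.png"))  # photo.01
print(batch_basename(None))                    # ''
```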
From 9d71eef02e7395e179b8d5e61e6d91ddd8928d2e Mon Sep 17 00:00:00 2001
From: winterspringsummer
Date: Fri, 21 Oct 2022 09:23:13 +0900
Subject: sort file list in alphabetical order in extras
---
modules/extras.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index ac85142c..22c5a1c1 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -39,7 +39,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if input_dir == '':
return outputs, "Please select an input directory.", ''
- image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+ image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
for img in image_list:
try:
image = Image.open(img)
--
cgit v1.2.3
From c23f666dba2b484d521d2dc4be91cf9e09312647 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 09:47:43 +0300
Subject: a more strict check for activation type and a more reasonable check
for type of layer in hypernets
---
modules/hypernetworks/hypernetwork.py | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7d617680..84e7e350 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -32,10 +32,16 @@ class HypernetworkModule(torch.nn.Module):
linears = []
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
if activation_func == "relu":
linears.append(torch.nn.ReLU())
- if activation_func == "leakyrelu":
+ elif activation_func == "leakyrelu":
linears.append(torch.nn.LeakyReLU())
+ elif activation_func == 'linear' or activation_func is None:
+ pass
+ else:
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
- if not "ReLU" in layer.__str__():
+ if type(layer) == torch.nn.Linear:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
@@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
- if not "ReLU" in layer.__str__():
+ if type(layer) == torch.nn.Linear:
layer_structure += [layer.weight, layer.bias]
return layer_structure
--
cgit v1.2.3
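The stricter check can equivalently be written as a dict lookup, the shape the `activation_dict` seen earlier in this file's history takes; unknown names fail loudly instead of being skipped. A hedged sketch:
```python
# Sketch: dict-based strict activation lookup.
import torch

activation_dict = {
    "relu": torch.nn.ReLU,
    "leakyrelu": torch.nn.LeakyReLU,
}

def make_activation(name):
    if name is None or name == "linear":
        return None
    if name not in activation_dict:
        raise RuntimeError(f"hypernetwork uses an unsupported activation function: {name}")
    return activation_dict[name]()

print(make_activation("relu"))    # ReLU()
print(make_activation("linear"))  # None
```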
From 7157e5d064741fa57ca81a2c6432a651f21ee82f Mon Sep 17 00:00:00 2001
From: Patryk Wychowaniec
Date: Thu, 20 Oct 2022 19:22:59 +0200
Subject: interrogate: Fix CLIP-interrogation on CPU
Currently, trying to perform CLIP interrogation on a CPU fails, saying:
```
RuntimeError: "slow_conv2d_cpu" not implemented for 'Half'
```
This merge request fixes this issue by detecting whether the target
device is CPU and, if so, force-enabling `--no-half` and passing
`device="cpu"` to `clip.load()` (which then does some extra tricks to
ensure it works correctly on CPU).
---
modules/interrogate.py | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 64b91eb4..65b05d34 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -28,9 +28,11 @@ class InterrogateModels:
clip_preprocess = None
categories = None
dtype = None
+ running_on_cpu = None
def __init__(self, content_dir):
self.categories = []
+ self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
if os.path.exists(content_dir):
for filename in os.listdir(content_dir):
@@ -53,7 +55,11 @@ class InterrogateModels:
def load_clip_model(self):
import clip
- model, preprocess = clip.load(clip_model_name)
+ if self.running_on_cpu:
+ model, preprocess = clip.load(clip_model_name, device="cpu")
+ else:
+ model, preprocess = clip.load(clip_model_name)
+
model.eval()
model = model.to(devices.device_interrogate)
@@ -62,14 +68,14 @@ class InterrogateModels:
def load(self):
if self.blip_model is None:
self.blip_model = self.load_blip_model()
- if not shared.cmd_opts.no_half:
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.blip_model = self.blip_model.half()
self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
- if not shared.cmd_opts.no_half:
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.clip_model = self.clip_model.half()
self.clip_model = self.clip_model.to(devices.device_interrogate)
--
cgit v1.2.3
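A minimal sketch of the CPU-safe loading pattern this commit describes, assuming the OpenAI `clip` package; `"ViT-L/14"` stands in for the repository's `clip_model_name`, and half precision is skipped on CPU because fp16 convolution kernels are not implemented there:
```python
# Sketch, not the repository's code: load CLIP safely on CPU or GPU.
import torch

def load_clip_for(device):
    import clip  # assumes the openai/CLIP package is installed

    if device == torch.device("cpu"):
        # telling clip.load() the device lets it apply its CPU-specific fixes
        model, preprocess = clip.load("ViT-L/14", device="cpu")
    else:
        model, preprocess = clip.load("ViT-L/14")
        model = model.half()  # half precision only off-CPU
    model.eval()
    return model.to(device), preprocess
```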
From b69c37d25e4ffc56e8f8c247fa2c38b4648cefb7 Mon Sep 17 00:00:00 2001
From: guaneec
Date: Thu, 20 Oct 2022 22:21:12 +0800
Subject: Allow datasets with only 1 image in TI
---
modules/textual_inversion/dataset.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 23bb4b6a..5b1c5002 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -83,7 +83,7 @@ class PersonalizedBase(Dataset):
self.dataset.append(entry)
- assert len(self.dataset) > 1, "No images have been found in the dataset."
+ assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(len(self.dataset))
@@ -91,7 +91,7 @@ class PersonalizedBase(Dataset):
self.shuffle()
def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
def create_text(self, filename_text):
text = random.choice(self.lines)
--
cgit v1.2.3
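The one-line shuffle fix converts the torch permutation to numpy before indexing, keeping the fancy indexing entirely inside numpy rather than mixing index types. The corrected pattern in isolation:
```python
# Sketch of the corrected shuffle: torch permutation, numpy indexing.
import numpy as np
import torch

initial_indexes = np.arange(5)
perm = torch.randperm(initial_indexes.shape[0]).numpy()
print(initial_indexes[perm])  # e.g. [3 0 4 1 2]
```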
From 5245c7a4935f67b677da0f5a1fc2b74c074aa0e2 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Wed, 19 Oct 2022 12:21:32 -0700
Subject: Issue #2921 - Give PNG info to Hypernet previews.
---
modules/hypernetworks/hypernetwork.py | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 84e7e350..68c8f26d 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -256,6 +256,9 @@ def stack_conds(conds):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ # images is required here to give training previews their infotext. Importing this at the very top causes a circular dependency.
+ from modules import images
+
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -298,6 +301,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
last_saved_file = ""
last_saved_image = ""
+ forced_filename = ""
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
@@ -345,7 +349,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
- last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+ forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
@@ -381,7 +386,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None:
shared.state.current_image = image
- image.save(last_saved_image)
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
--
cgit v1.2.3
From 6014fb8afbe05c8d02fffe7a36a2e48128713bd2 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Wed, 19 Oct 2022 12:22:23 -0700
Subject: Do nothing if image file already exists.
---
modules/images.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index b9589563..550e53ae 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -416,7 +416,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
path = os.path.join(path, dirname)
- os.makedirs(path, exist_ok=True)
+ try:
+ os.makedirs(path, exist_ok=True)
+ except FileExistsError:
+ # If the file already exists, continue and allow said file to be overwritten.
+ pass
if forced_filename is None:
basecount = get_next_sequence_number(path, basename)
--
cgit v1.2.3
From 4ff274e1e35bb642687253ce744d2cfa738ab293 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Wed, 19 Oct 2022 12:32:22 -0700
Subject: Revise comments.
---
modules/hypernetworks/hypernetwork.py | 2 +-
modules/images.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 68c8f26d..3f96361c 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -256,7 +256,7 @@ def stack_conds(conds):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- # images is required here to give training previews their infotext. Importing this at the very top causes a circular dependency.
+ # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
assert hypernetwork_name, 'hypernetwork not selected'
diff --git a/modules/images.py b/modules/images.py
index 550e53ae..b8834e3c 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -419,7 +419,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
try:
os.makedirs(path, exist_ok=True)
except FileExistsError:
- # If the file already exists, continue and allow said file to be overwritten.
+ # If the file already exists, allow said file to be overwritten.
pass
if forced_filename is None:
--
cgit v1.2.3
From 2273e752fb3e578f1047f6d38b96330b07bf61a9 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Wed, 19 Oct 2022 14:23:48 -0700
Subject: Remove redundant try/except.
---
modules/images.py | 6 +-----
1 file changed, 1 insertion(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index b8834e3c..b9589563 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -416,11 +416,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
path = os.path.join(path, dirname)
- try:
- os.makedirs(path, exist_ok=True)
- except FileExistsError:
- # If the file already exists, allow said file to be overwritten.
- pass
+ os.makedirs(path, exist_ok=True)
if forced_filename is None:
basecount = get_next_sequence_number(path, basename)
--
cgit v1.2.3
From 03a1e288c4973dd2dff57a97469b40f146b6fccf Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 10:13:24 +0300
Subject: turns out LayerNorm also has weight and bias and needs to be
pre-multiplied and trained for hypernets
---
modules/hypernetworks/hypernetwork.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3274a802..b1a5d0c7 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -52,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
- if type(layer) == torch.nn.Linear:
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
@@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
- if type(layer) == torch.nn.Linear:
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer_structure += [layer.weight, layer.bias]
return layer_structure
--
cgit v1.2.3
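Because `LayerNorm` carries its own affine `weight` and `bias`, both the init loop and `trainables()` must treat it like `Linear`. A small sketch of collecting trainables from a mixed stack:
```python
# Sketch: collect weight/bias from Linear and LayerNorm layers alike.
import torch

layers = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.LayerNorm(8),
    torch.nn.ReLU(),
)

trainables = [
    p
    for layer in layers
    if type(layer) in (torch.nn.Linear, torch.nn.LayerNorm)
    for p in (layer.weight, layer.bias)
]
print(len(trainables))  # 4: two weights and two biases
```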
From bf30673f5132c8f28357b31224c54331e788d3e7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 10:19:25 +0300
Subject: Fix Hypernet infotext string split bug for PR #3283
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 21786968..d1deffa9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else os.path.splitext(os.path.basename(shared.loaded_hypernetwork.filename))[0]),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
--
cgit v1.2.3
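The fix matters because `split('\\')` ignores POSIX separators and `split('.')[0]` truncates at the first dot. A quick comparison on a hypothetical model path:
```python
# Demonstration of the old vs. new Hypernet name extraction.
import os

fn = "/models/hypernetworks/anime.v2.pt"
old = fn.split('\\')[-1].split('.')[0]
new = os.path.splitext(os.path.basename(fn))[0]
print(old)  # /models/hypernetworks/anime  (directories leak into the name)
print(new)  # anime.v2
```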
From df5706409386cc2e88718bd9101045587c39f8bb Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 16:10:51 +0300
Subject: do not load aesthetic clip model until it's needed; add refresh button
for aesthetic embeddings; add aesthetic params to images' infotext
---
modules/aesthetic_clip.py | 40 +++++++++++++++++++----
modules/generation_parameters_copypaste.py | 18 +++++++++--
modules/img2img.py | 5 +--
modules/processing.py | 4 +--
modules/sd_models.py | 3 --
modules/txt2img.py | 4 +--
modules/ui.py | 52 ++++++++++++++++++++----------
7 files changed, 88 insertions(+), 38 deletions(-)
(limited to 'modules')
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
index 34efa931..8c828541 100644
--- a/modules/aesthetic_clip.py
+++ b/modules/aesthetic_clip.py
@@ -40,6 +40,8 @@ def iter_to_batched(iterable, n=1):
def create_ui():
+ import modules.ui
+
with gr.Group():
with gr.Accordion("Open for Clip Aesthetic!", open=False):
with gr.Row():
@@ -55,6 +57,8 @@ def create_ui():
label="Aesthetic imgs embedding",
value="None")
+ modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")
+
with gr.Row():
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
placeholder="This text is used to rotate the feature space of the imgs embs",
@@ -66,11 +70,21 @@ def create_ui():
return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
+aesthetic_clip_model = None
+
+
+def aesthetic_clip():
+ global aesthetic_clip_model
+
+ if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
+ aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
+ aesthetic_clip_model.cpu()
+
+ return aesthetic_clip_model
+
+
def generate_imgs_embd(name, folder, batch_size):
- # clipModel = CLIPModel.from_pretrained(
- # shared.sd_model.cond_stage_model.clipModel.name_or_path
- # )
- model = shared.clip_model.to(device)
+ model = aesthetic_clip().to(device)
processor = CLIPProcessor.from_pretrained(model.name_or_path)
with torch.no_grad():
@@ -91,7 +105,7 @@ def generate_imgs_embd(name, folder, batch_size):
path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
torch.save(embs, path)
- model = model.cpu()
+ model.cpu()
del processor
del embs
gc.collect()
@@ -132,7 +146,7 @@ class AestheticCLIP:
self.image_embs = None
self.load_image_embs(None)
- def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
+ def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
aesthetic_slerp=True, aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False):
@@ -145,6 +159,18 @@ class AestheticCLIP:
self.aesthetic_steps = aesthetic_steps
self.load_image_embs(image_embs_name)
+ if self.image_embs_name is not None:
+ p.extra_generation_params.update({
+ "Aesthetic LR": aesthetic_lr,
+ "Aesthetic weight": aesthetic_weight,
+ "Aesthetic steps": aesthetic_steps,
+ "Aesthetic embedding": self.image_embs_name,
+ "Aesthetic slerp": aesthetic_slerp,
+ "Aesthetic text": aesthetic_imgs_text,
+ "Aesthetic text negative": aesthetic_text_negative,
+ "Aesthetic slerp angle": aesthetic_slerp_angle,
+ })
+
def set_skip(self, skip):
self.skip = skip
@@ -168,7 +194,7 @@ class AestheticCLIP:
tokens = torch.asarray(remade_batch_tokens).to(device)
- model = copy.deepcopy(shared.clip_model).to(device)
+ model = copy.deepcopy(aesthetic_clip()).to(device)
model.requires_grad_(True)
if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
text_embs_2 = model.get_text_features(
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 0f041449..f73647da 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -4,13 +4,22 @@ import gradio as gr
from modules.shared import script_path
from modules import shared
-re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
+re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())
+def quote(text):
+ if ',' not in str(text):
+ return text
+
+ text = str(text)
+ text = text.replace('\\', '\\\\')
+ text = text.replace('"', '\\"')
+ return f'"{text}"'
+
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
else:
try:
valtype = type(output.value)
- val = valtype(v)
+
+ if valtype == bool and v == "False":
+ val = False
+ else:
+ val = valtype(v)
+
res.append(gr.update(value=val))
except Exception:
res.append(gr.update())
diff --git a/modules/img2img.py b/modules/img2img.py
index bc7c66bc..eea5199b 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -109,10 +109,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert,
)
- shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
- aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text,
- aesthetic_slerp_angle,
- aesthetic_text_negative)
+ shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/processing.py b/modules/processing.py
index d1deffa9..f0852cd5 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -12,7 +12,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -318,7 +318,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params.update(p.extra_generation_params)
- generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+ generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 05a1df28..b1c91b0d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -234,9 +234,6 @@ def load_model(checkpoint_info=None):
sd_hijack.model_hijack.hijack(sd_model)
- if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path:
- shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path)
-
sd_model.eval()
print(f"Model loaded.")
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 32ed1d8d..1761cfa2 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -36,9 +36,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
firstphase_height=firstphase_height if enable_hr else None,
)
- shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
- aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle,
- aesthetic_text_negative)
+ shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/ui.py b/modules/ui.py
index 381ca925..0d020de6 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -597,27 +597,29 @@ def apply_setting(key, value):
return value
-def create_ui(wrap_gradio_gpu_call):
- import modules.img2img
- import modules.txt2img
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
- def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
- for k, v in args.items():
- setattr(refresh_component, k, v)
+ return gr.update(**(args or {}))
- return gr.update(**(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
+
+
+def create_ui(wrap_gradio_gpu_call):
+ import modules.img2img
+ import modules.txt2img
- refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn = refresh,
- inputs = [],
- outputs = [refresh_component]
- )
- return refresh_button
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
@@ -802,6 +804,14 @@ def create_ui(wrap_gradio_gpu_call):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
+ (aesthetic_lr, "Aesthetic LR"),
+ (aesthetic_weight, "Aesthetic weight"),
+ (aesthetic_steps, "Aesthetic steps"),
+ (aesthetic_imgs, "Aesthetic embedding"),
+ (aesthetic_slerp, "Aesthetic slerp"),
+ (aesthetic_imgs_text, "Aesthetic text"),
+ (aesthetic_text_negative, "Aesthetic text negative"),
+ (aesthetic_slerp_angle, "Aesthetic slerp angle"),
]
txt2img_preview_params = [
@@ -1077,6 +1087,14 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
+ (aesthetic_lr_im, "Aesthetic LR"),
+ (aesthetic_weight_im, "Aesthetic weight"),
+ (aesthetic_steps_im, "Aesthetic steps"),
+ (aesthetic_imgs_im, "Aesthetic embedding"),
+ (aesthetic_slerp_im, "Aesthetic slerp"),
+ (aesthetic_imgs_text_im, "Aesthetic text"),
+ (aesthetic_text_negative_im, "Aesthetic text negative"),
+ (aesthetic_slerp_angle_im, "Aesthetic slerp angle"),
]
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
--
cgit v1.2.3
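
Editor's note: a minimal, self-contained sketch (not part of the patch) showing why the new quote() helper and the widened re_param regex above round-trip infotext values that contain commas. The function and regex mirror the diff; the sample parameters are invented for illustration.

    import re

    # the widened infotext regex from the patch: a value is either a
    # quoted string (which may contain commas) or a run of non-comma chars
    re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)')

    def quote(text):
        # mirror of the patched quote(): only values containing a comma
        # get backslash-escaped and wrapped in double quotes
        if ',' not in str(text):
            return text
        text = str(text)
        text = text.replace('\\', '\\\\')
        text = text.replace('"', '\\"')
        return f'"{text}"'

    params = {"Steps": 20, "Aesthetic text": "flowers, oil painting"}
    infotext = ", ".join(f"{k}: {quote(v)}" for k, v in params.items())
    print(infotext)
    # Steps: 20, Aesthetic text: "flowers, oil painting"
    print(re_param.findall(infotext))
    # [('Steps', '20'), ('Aesthetic text', '"flowers, oil painting"')]

Without quote(), the comma inside "flowers, oil painting" would split the value when the infotext is pasted back into the UI.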
From 9286fe53de2eef91f13cc3ad5938ddf67ecc8413 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 16:38:06 +0300
Subject: make aesthetic embedding compatible with prompts longer than 75
tokens
---
modules/sd_hijack.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 36198a3c..1f8587d1 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -332,8 +332,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
+ z1 = shared.aesthetic_clip(z1, remade_batch_tokens)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
- z = shared.aesthetic_clip(z, remade_batch_tokens)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
--
cgit v1.2.3
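
Editor's note: a toy sketch (assumed shapes, stand-in functions) of what this one-line move changes. Long prompts are encoded in 75-token chunks; the aesthetic adjustment must be applied to each chunk embedding z1 before concatenation, not once to the already-concatenated z.

    import torch

    def embed_in_chunks(tokens, chunk_size, embed, adjust):
        # the fix: apply the aesthetic adjustment per 75-token chunk (z1)
        # before concatenation, instead of once on the growing z
        z = None
        for i in range(0, len(tokens), chunk_size):
            z1 = embed(tokens[i:i + chunk_size])
            z1 = adjust(z1)  # per chunk, as in the patched sd_hijack.py
            z = z1 if z is None else torch.cat((z, z1), dim=-2)
        return z

    embed = lambda chunk: torch.randn(1, len(chunk), 768)  # stand-in encoder
    adjust = lambda z: z * 1.0                             # stand-in aesthetic step
    print(embed_in_chunks(list(range(150)), 75, embed, adjust).shape)
    # torch.Size([1, 150, 768])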
From d0ea471b0cdaede163c6e7f6fae8535f5c3cd226 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Fri, 21 Oct 2022 14:04:41 +0100
Subject: Use opts in textual_inversion image_embedding.py for dynamic fonts
---
modules/textual_inversion/image_embedding.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index 898ce3b3..c50b1e7b 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -5,6 +5,7 @@ import zlib
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
from fonts.ttf import Roboto
import torch
+from modules.shared import opts
class EmbeddingEncoder(json.JSONEncoder):
--
cgit v1.2.3
From 306e2ff6ab8f4c7e94ab55f4f08ab8f94d73d287 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Fri, 21 Oct 2022 14:47:21 +0100
Subject: Update image_embedding.py
---
modules/textual_inversion/image_embedding.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index c50b1e7b..ea653806 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -134,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
from math import cos
image = srcimage.copy()
-
+ fontsize = 32
if textfont is None:
try:
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
@@ -151,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
draw = ImageDraw.Draw(image)
- fontsize = 32
+
font = ImageFont.truetype(textfont, fontsize)
padding = 10
--
cgit v1.2.3
From 51e3dc9ccad157d7161b697a246e26c868d46a7c Mon Sep 17 00:00:00 2001
From: timntorres
Date: Fri, 21 Oct 2022 02:11:12 -0700
Subject: Sanitize hypernet name input.
---
modules/hypernetworks/ui.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'modules')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 266f04f6..e6f50a1f 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -11,6 +11,9 @@ from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None):
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
--
cgit v1.2.3
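
Editor's note: the sanitization filter added above, extracted as a tiny standalone function for clarity. The character whitelist is exactly the one in the diff; the example input is invented.

    def sanitize_name(name: str) -> str:
        # keep only alphanumerics and the safe punctuation "._- ",
        # matching the filter added in create_hypernetwork()
        return "".join(x for x in name if x.isalnum() or x in "._- ")

    print(sanitize_name("my/net: v2?"))  # "mynet v2"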
From 19818f023cfafc472c6c241cab0b72896a168481 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Fri, 21 Oct 2022 02:14:02 -0700
Subject: Match hypernet name with filename in all cases.
---
modules/hypernetworks/hypernetwork.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b1a5d0c7..6d392be4 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -340,7 +340,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
pbar.set_description(f"loss: {mean_loss:.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
+ temp = hypernetwork.name
+ # Before saving, change name to match current checkpoint.
+ hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
@@ -405,6 +408,9 @@ Last saved image: {html.escape(last_saved_image)}
hypernetwork.sd_checkpoint = checkpoint.hash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
+ hypernetwork.name = hypernetwork_name
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(filename)
return hypernetwork, filename
--
cgit v1.2.3
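
Editor's note: a condensed sketch of the naming convention these two commits enforce; the helper and directory name are illustrative, not repo code. Periodic checkpoints get a step-suffixed name, and the hypernetwork's .name field is set to the same string so name and filename always agree.

    import os

    def checkpoint_filename(base_name, step=None, directory="hypernetworks"):
        # periodic saves are step-suffixed; the final save restores the base name
        name = base_name if step is None else f"{base_name}-{step}"
        return name, os.path.join(directory, f"{name}.pt")

    print(checkpoint_filename("mynet", step=500))
    # ('mynet-500', 'hypernetworks/mynet-500.pt')
    print(checkpoint_filename("mynet"))
    # ('mynet', 'hypernetworks/mynet.pt')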
From fccad18a59e3c2c33fefbbb1763c6a87a3a68eba Mon Sep 17 00:00:00 2001
From: timntorres
Date: Fri, 21 Oct 2022 02:17:26 -0700
Subject: Refer to Hypernet's name, sensibly, by its name variable.
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index f0852cd5..ff1ec4c9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else os.path.splitext(os.path.basename(shared.loaded_hypernetwork.filename))[0]),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
--
cgit v1.2.3
From 272fa527bbe93143668ffc16838107b7dca35b40 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Fri, 21 Oct 2022 02:41:55 -0700
Subject: Remove unused variable.
---
modules/hypernetworks/hypernetwork.py | 1 -
1 file changed, 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 6d392be4..47d91ea5 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -340,7 +340,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
pbar.set_description(f"loss: {mean_loss:.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
- temp = hypernetwork.name
# Before saving, change name to match current checkpoint.
hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
--
cgit v1.2.3
From 02e4d4694dd9254a6ca9f05c2eb7b01ea508abc7 Mon Sep 17 00:00:00 2001
From: Rcmcpe
Date: Fri, 21 Oct 2022 15:53:35 +0800
Subject: Change option description of unload_models_when_training
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 5c675b80..41d7f08e 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -266,7 +266,7 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('training', "Training"), {
- "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+ "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
--
cgit v1.2.3
From 704036ff07b71bf86cadcbbff2bcfeebdd1ed3a6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 17:11:42 +0300
Subject: make aspect ratio overlay work regardless of selected localization
---
modules/ui.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 0d020de6..85f95792 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -879,8 +879,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
--
cgit v1.2.3
From ac0aa2b18efeeb9220a5994c8dd54c7cdda7cc40 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 17:35:51 +0300
Subject: loading SD VAE, see PR #3303
---
modules/sd_models.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b1c91b0d..d99dbce8 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -155,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
+vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+
+
def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
@@ -186,7 +189,7 @@ def load_model_weights(model, checkpoint_info):
if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+ vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
--
cgit v1.2.3
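
Editor's note: the VAE state-dict filter from the diff, isolated as a runnable sketch with a fabricated checkpoint dict. The patch drops loss-related keys (k[0:4] != "loss") and now also the EMA bookkeeping keys that some VAE checkpoints carry, which would otherwise make load_state_dict fail with unexpected keys.

    import torch

    vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}

    def filter_vae_state_dict(vae_ckpt):
        # same filter as the patched load_model_weights()
        return {k: v for k, v in vae_ckpt["state_dict"].items()
                if k[0:4] != "loss" and k not in vae_ignore_keys}

    ckpt = {"state_dict": {
        "decoder.conv_in.weight": torch.zeros(1),
        "loss.logvar": torch.zeros(1),
        "model_ema.decay": torch.tensor(0.999),
    }}
    print(sorted(filter_vae_state_dict(ckpt)))  # ['decoder.conv_in.weight']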
From f49c08ea566385db339c6628f65c3a121033f67c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 18:46:02 +0300
Subject: prevent error spam when processing images without txt files for
captions
---
modules/textual_inversion/preprocess.py | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 17e4ddc1..33eaddb6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -122,11 +122,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
continue
existing_caption = None
-
- try:
- existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read()
- except Exception as e:
- print(e)
+ existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
+ if os.path.exists(existing_caption_filename):
+ with open(existing_caption_filename, 'r', encoding="utf8") as file:
+ existing_caption = file.read()
if shared.state.interrupted:
break
--
cgit v1.2.3
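
Editor's note: the caption-lookup logic above, pulled out into a small function. Checking os.path.exists first replaces the old try/except that printed an exception for every image without a sidecar .txt file; the filename below is a made-up example.

    import os

    def read_caption(image_filename):
        # look for a sidecar .txt file; return None quietly when absent
        caption_filename = os.path.splitext(image_filename)[0] + '.txt'
        if os.path.exists(caption_filename):
            with open(caption_filename, 'r', encoding="utf8") as file:
                return file.read()
        return None

    print(read_caption("dataset/img_0001.png"))
    # None unless dataset/img_0001.txt exists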
From bb0f1a2cdae3410a41d06ae878f56e29b8154c41 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Sat, 22 Oct 2022 01:23:00 +0800
Subject: inspiration finished
---
modules/inspiration.py | 192 ++++++++++++++++++++++++++++++++-----------------
modules/shared.py | 6 ++
modules/ui.py | 2 +-
3 files changed, 133 insertions(+), 67 deletions(-)
(limited to 'modules')
diff --git a/modules/inspiration.py b/modules/inspiration.py
index 456bfcb5..f72ebf3a 100644
--- a/modules/inspiration.py
+++ b/modules/inspiration.py
@@ -1,122 +1,182 @@
import os
import random
-import gradio
-inspiration_path = "inspiration"
-inspiration_system_path = os.path.join(inspiration_path, "system")
-def read_name_list(file):
+import gradio
+from modules.shared import opts
+inspiration_system_path = os.path.join(opts.inspiration_dir, "system")
+def read_name_list(file, types=None, keyword=None):
if not os.path.exists(file):
return []
- f = open(file, "r")
ret = []
+ f = open(file, "r")
line = f.readline()
while len(line) > 0:
line = line.rstrip("\n")
- ret.append(line)
- print(ret)
+ if types is not None:
+ dirname = os.path.split(line)
+ if dirname[0] in types and keyword in dirname[1]:
+ ret.append(line)
+ else:
+ ret.append(line)
+ line = f.readline()
return ret
def save_name_list(file, name):
- print(file)
- f = open(file, "a")
- f.write(name + "\n")
+ with open(file, "a") as f:
+ f.write(name + "\n")
-def get_inspiration_images(source, types):
- path = os.path.join(inspiration_path , types)
+def get_types_list():
+ files = os.listdir(opts.inspiration_dir)
+ types = []
+ for x in files:
+ path = os.path.join(opts.inspiration_dir, x)
+ if x[0] == ".":
+ continue
+ if not os.path.isdir(path):
+ continue
+ if path == inspiration_system_path:
+ continue
+ types.append(x)
+ return types
+
+def get_inspiration_images(source, types, keyword):
+ get_num = int(opts.inspiration_rows_num * opts.inspiration_cols_num)
if source == "Favorites":
- names = read_name_list(os.path.join(inspiration_system_path, types + "_faverites.txt"))
- names = random.sample(names, 25)
+ names = read_name_list(os.path.join(inspiration_system_path, "faverites.txt"), types, keyword)
+ names = random.sample(names, get_num) if len(names) > get_num else names
elif source == "Abandoned":
- names = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt"))
- names = random.sample(names, 25)
- elif source == "Exclude abandoned":
- abondened = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt"))
- all_names = os.listdir(path)
- names = []
- while len(names) < 25:
- name = random.choice(all_names)
- if name not in abondened:
- names.append(name)
+ names = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword)
+ print(names)
+ names = random.sample(names, get_num) if len(names) > get_num else names
+ elif source == "Exclude abandoned":
+ abandoned = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword)
+ all_names = []
+ for tp in types:
+ name_list = os.listdir(os.path.join(opts.inspiration_dir, tp))
+ all_names += [os.path.join(tp, x) for x in name_list if keyword in x]
+
+ if len(all_names) > get_num:
+ names = []
+ while len(names) < get_num:
+ name = random.choice(all_names)
+ if name not in abandoned:
+ names.append(name)
+ else:
+ names = all_names
else:
- names = random.sample(os.listdir(path), 25)
- names = random.sample(names, 25)
+ all_names = []
+ for tp in types:
+ name_list = os.listdir(os.path.join(opts.inspiration_dir, tp))
+ all_names += [os.path.join(tp, x) for x in name_list if keyword in x]
+ names = random.sample(all_names, get_num) if len(all_names) > get_num else all_names
image_list = []
for a in names:
- image_path = os.path.join(path, a)
+ image_path = os.path.join(opts.inspiration_dir, a)
images = os.listdir(image_path)
- image_list.append(os.path.join(image_path, random.choice(images)))
- return image_list, names
+ image_list.append((os.path.join(image_path, random.choice(images)), a))
+ return image_list, names, ""
-def select_click(index, types, name_list):
+def select_click(index, name_list):
name = name_list[int(index)]
- path = os.path.join(inspiration_path, types, name)
+ path = os.path.join(opts.inspiration_dir, name)
images = os.listdir(path)
- return name, [os.path.join(path, x) for x in images]
+ return name, [os.path.join(path, x) for x in images], ""
-def give_up_click(name, types):
- file = os.path.join(inspiration_system_path, types + "_abandoned.txt")
+def give_up_click(name):
+ file = os.path.join(inspiration_system_path, "abandoned.txt")
name_list = read_name_list(file)
if name not in name_list:
save_name_list(file, name)
+ return "Added to abandoned list"
-def collect_click(name, types):
- file = os.path.join(inspiration_system_path, types + "_faverites.txt")
- print(file)
+def collect_click(name):
+ file = os.path.join(inspiration_system_path, "faverites.txt")
name_list = read_name_list(file)
- print(name_list)
if name not in name_list:
save_name_list(file, name)
+ return "Added to faverite list"
-def moveout_click(name, types):
- file = os.path.join(inspiration_system_path, types + "_faverites.txt")
+def moveout_click(name, source):
+ if source == "Abandoned":
+ file = os.path.join(inspiration_system_path, "abandoned.txt")
+ if source == "Favorites":
+ file = os.path.join(inspiration_system_path, "faverites.txt")
+ else:
+ return None
name_list = read_name_list(file)
- if name not in name_list:
- save_name_list(file, name)
+ os.remove(file)
+ with open(file, "a") as f:
+ for a in name_list:
+ if a != name:
+ f.write(a)
+ return "Moved out {name} from {source} list"
def source_change(source):
- if source == "Abandoned" or source == "Favorites":
- return gradio.Button.update(visible=True, value=f"Move out {source}")
+ if source in ["Abandoned", "Favorites"]:
+ return gradio.update(visible=True), []
else:
- return gradio.Button.update(visible=False)
+ return gradio.update(visible=False), []
+def add_to_prompt(name, prompt):
+ print(name, prompt)
+ name = os.path.basename(name)
+ return prompt + "," + name
-def ui(gr, opts):
+def ui(gr, opts, txt2img_prompt, img2img_prompt):
with gr.Blocks(analytics_enabled=False) as inspiration:
- flag = os.path.exists(inspiration_path)
+ flag = os.path.exists(opts.inspiration_dir)
if flag:
- types = os.listdir(inspiration_path)
- types = [x for x in types if x != "system"]
+ types = get_types_list()
flag = len(types) > 0
- if not flag:
- os.mkdir(inspiration_path)
+ else:
+ os.makedirs(opts.inspiration_dir)
+ if not flag:
gr.HTML("""
- "
+
To activate the inspiration function, you need to get "inspiration" images first.
+ You can create these images by running the "Create inspiration images" script on the txt2img page;
you can get the list of artists or art styles from here:
+
https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
+ download these files, and select them in the "Create inspiration images" script UI.
+ There are about 6000 artists and art styles in these files.
This takes several hours depending on your GPU type and how many pictures you generate for each artist/style;
+
I suggest at least four images for each.
+
You can also download generated pictures from here:
+
https://huggingface.co/datasets/yfszzx/inspiration
+ unzip the file into the project directory of webui,
+ then restart webui and enjoy the joy of creation!
""")
return inspiration
if not os.path.exists(inspiration_system_path):
os.mkdir(inspiration_system_path)
- gallery, names = get_inspiration_images("Exclude abandoned", types[0])
with gr.Row():
with gr.Column(scale=2):
- inspiration_gallery = gr.Gallery(gallery, show_label=False, elem_id="inspiration_gallery").style(grid=5, height='auto')
+ inspiration_gallery = gr.Gallery(show_label=False, elem_id="inspiration_gallery").style(grid=opts.inspiration_cols_num, height='auto')
with gr.Column(scale=1):
- types = gr.Dropdown(choices=types, value=types[0], label="Type", visible=len(types) > 1)
+ print(types)
+ types = gr.CheckboxGroup(choices=types, value=types)
+ keyword = gr.Textbox("", label="Key word")
with gr.Row():
source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source")
- get_inspiration = gr.Button("Get inspiration")
+ get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button")
name = gr.Textbox(show_label=False, interactive=False)
with gr.Row():
send_to_txt2img = gr.Button('to txt2img')
send_to_img2img = gr.Button('to img2img')
- style_gallery = gr.Gallery(show_label=False, elem_id="inspiration_style_gallery").style(grid=2, height='auto')
-
+ style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto')
collect = gr.Button('Collect')
- give_up = gr.Button("Don't show any more")
+ give_up = gr.Button("Don't show again")
moveout = gr.Button("Move out", visible=False)
- with gr.Row():
+ warning = gr.HTML()
+ with gr.Row(visible=False):
select_button = gr.Button('set button', elem_id="inspiration_select_button")
- name_list = gr.State(names)
- source.change(source_change, inputs=[source], outputs=[moveout])
- get_inspiration.click(get_inspiration_images, inputs=[source, types], outputs=[inspiration_gallery, name_list])
- select_button.click(select_click, _js="inspiration_selected", inputs=[name, types, name_list], outputs=[name, style_gallery])
- give_up.click(give_up_click, inputs=[name, types], outputs=None)
- collect.click(collect_click, inputs=[name, types], outputs=None)
+ name_list = gr.State()
+
+ get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list, keyword])
+ source.change(source_change, inputs=[source], outputs=[moveout, style_gallery])
+ source.change(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None)
+ keyword.submit(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None)
+ select_button.click(select_click, _js="inspiration_selected", inputs=[name, name_list], outputs=[name, style_gallery, warning])
+ give_up.click(give_up_click, inputs=[name], outputs=[warning])
+ collect.click(collect_click, inputs=[name], outputs=[warning])
+ moveout.click(moveout_click, inputs=[name, source], outputs=[warning])
+ send_to_txt2img.click(add_to_prompt, inputs=[name, txt2img_prompt], outputs=[txt2img_prompt])
+ send_to_img2img.click(add_to_prompt, inputs=[name, img2img_prompt], outputs=[img2img_prompt])
+ send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None)
+ send_to_img2img.click(None, _js="switch_to_img2img_img2img", inputs=None, outputs=None)
return inspiration
diff --git a/modules/shared.py b/modules/shared.py
index ae033710..564b1b8d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -316,6 +316,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
+options_templates.update(options_section(('inspiration', "Inspiration"), {
+ "inspiration_dir": OptionInfo("inspiration", "Directory of inspiration", component_args=hide_dirs),
+ "inspiration_max_samples": OptionInfo(4, "Maximum number of samples, used to determine which folders to skip when continue running the create script", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
+ "inspiration_rows_num": OptionInfo(4, "Rows of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}),
+ "inspiration_cols_num": OptionInfo(8, "Columns of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}),
+}))
class Options:
data = None
diff --git a/modules/ui.py b/modules/ui.py
index 6a0a3c3b..b651eb9c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1180,7 +1180,7 @@ def create_ui(wrap_gradio_gpu_call):
}
browser_interface = images_history.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
- inspiration_interface = inspiration.ui(gr, opts)
+ inspiration_interface = inspiration.ui(gr, opts, txt2img_prompt, img2img_prompt)
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
--
cgit v1.2.3
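
Editor's note: the "Exclude abandoned" branch above samples with a rejection loop (pick a random name, keep it if not abandoned), which can select the same name twice. A set-difference variant is sketched below as an alternative, not as the repo's code; the sample names are invented.

    import random

    def sample_excluding(all_names, abandoned, n):
        # filter first, then sample: no duplicates, and no unbounded
        # retry loop when most names are abandoned
        abandoned = set(abandoned)
        pool = [x for x in all_names if x not in abandoned]
        return random.sample(pool, n) if len(pool) > n else pool

    names = [f"artists/a{i}" for i in range(10)]
    print(sample_excluding(names, ["artists/a3", "artists/a7"], 4))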
From 9ba372de90df81c4f1e992d8b33ae17c6630de95 Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Fri, 21 Oct 2022 13:55:42 -0500
Subject: initial work on getting prompts cleared on the backend and
synchronizing token counter
---
modules/ui.py | 29 +++++++++++++++++++----------
1 file changed, 19 insertions(+), 10 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index d2cb528e..2748a2e0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -429,15 +429,16 @@ def create_seed_inputs():
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
-# setup button for clearing prompt input boxes on client side of webui
-def connect_trash_prompt(dummy_component, button, is_img2img):
+def clear_prompt(prompt):
+ print(f"type: '{prompt}'")
+ print(prompt)
+
+ # update_token_counter(prompt, steps)
+ return ""
+
+def connect_trash_prompt(prompt, confirmed):
+ if(confirmed): clear_prompt(prompt)
- button.click(
- fn=lambda: print("Clearing prompt"),
- _js="trash_prompt",
- inputs=[],
- outputs=[],
- )
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
@@ -640,7 +641,16 @@ def create_ui(wrap_gradio_gpu_call):
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
- connect_trash_prompt(dummy_component, trash_prompt_button, False)
+
+
+ trash_prompt_button.click(
+ # fn=lambda: print("Clearing prompt"),
+ _js="trash_prompt",
+ fn=lambda: clear_prompt(),
+ inputs=[txt2img_prompt, dummy_component],
+ outputs=[txt2img_prompt, dummy_component],
+ )
+
with gr.Row(elem_id='txt2img_progress_row'):
with gr.Column(scale=1):
@@ -848,7 +858,6 @@ def create_ui(wrap_gradio_gpu_call):
img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\
token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True)
- connect_trash_prompt(dummy_component,trash_prompt_button, True)
with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
--
cgit v1.2.3
From ee0505dd0092ae7073b77aba93a858bda000dc60 Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Fri, 21 Oct 2022 14:24:14 -0500
Subject: only delete prompt on back end and remove client-side deletion
---
modules/ui.py | 29 +++++++++++++++--------------
1 file changed, 15 insertions(+), 14 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 2748a2e0..90c338da 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -429,15 +429,21 @@ def create_seed_inputs():
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
-def clear_prompt(prompt):
- print(f"type: '{prompt}'")
- print(prompt)
- # update_token_counter(prompt, steps)
- return ""
+def connect_trash_prompt(_, confirmed):
+ if(confirmed):
+ # update_token_counter(prompt, steps)
+ return ["", confirmed]
-def connect_trash_prompt(prompt, confirmed):
- if(confirmed): clear_prompt(prompt)
+def trash_prompt_click(button, prompt):
+ dummy_component = gradio.Label(visible=False)
+
+ button.click(
+ _js="trash_prompt",
+ fn=connect_trash_prompt,
+ inputs=[prompt, dummy_component],
+ outputs=[prompt, dummy_component],
+ )
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
@@ -643,13 +649,7 @@ def create_ui(wrap_gradio_gpu_call):
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
- trash_prompt_button.click(
- # fn=lambda: print("Clearing prompt"),
- _js="trash_prompt",
- fn=lambda: clear_prompt(),
- inputs=[txt2img_prompt, dummy_component],
- outputs=[txt2img_prompt, dummy_component],
- )
+ trash_prompt_click(trash_prompt_button, txt2img_prompt)
with gr.Row(elem_id='txt2img_progress_row'):
@@ -858,6 +858,7 @@ def create_ui(wrap_gradio_gpu_call):
img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\
token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True)
+ trash_prompt_click(trash_prompt_button, img2img_prompt)
with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
--
cgit v1.2.3
From de70ddaf58fae98c561738a54f574abfa14cd8d1 Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Fri, 21 Oct 2022 15:00:35 -0500
Subject: update token counter when clearing prompt
---
modules/ui.py | 17 +++++++----------
1 file changed, 7 insertions(+), 10 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 90c338da..d3a89bf7 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -430,19 +430,16 @@ def create_seed_inputs():
-def connect_trash_prompt(_, confirmed):
+def connect_trash_prompt(_prompt, confirmed, _token_counter):
if(confirmed):
- # update_token_counter(prompt, steps)
- return ["", confirmed]
-
-def trash_prompt_click(button, prompt):
- dummy_component = gradio.Label(visible=False)
+ return ["", confirmed, update_token_counter("", 1)]
+def trash_prompt_click(button, prompt, _dummy_confirmed, token_counter):
button.click(
_js="trash_prompt",
fn=connect_trash_prompt,
- inputs=[prompt, dummy_component],
- outputs=[prompt, dummy_component],
+ inputs=[prompt, _dummy_confirmed, token_counter],
+ outputs=[prompt, _dummy_confirmed, token_counter],
)
@@ -649,7 +646,6 @@ def create_ui(wrap_gradio_gpu_call):
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
- trash_prompt_click(trash_prompt_button, txt2img_prompt)
with gr.Row(elem_id='txt2img_progress_row'):
@@ -720,6 +716,7 @@ def create_ui(wrap_gradio_gpu_call):
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
+ trash_prompt_click(trash_prompt_button, txt2img_prompt, dummy_component, token_counter)
txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
@@ -858,7 +855,6 @@ def create_ui(wrap_gradio_gpu_call):
img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\
token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True)
- trash_prompt_click(trash_prompt_button, img2img_prompt)
with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
@@ -958,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call):
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
+ trash_prompt_click(trash_prompt_button, img2img_prompt, dummy_component, token_counter)
img2img_prompt_img.change(
fn=modules.images.image_data,
--
cgit v1.2.3
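
Editor's note: a sketch of the data flow this commit sets up. The handler's return values line up positionally with the click event's outputs list (prompt, confirmation flag, token counter); in the webui the counter value comes from update_token_counter("", 1), for which "0/75" is a stand-in here. Worth noting that the patched function returns nothing when confirmed is False; this sketch returns the inputs unchanged instead.

    def clear_prompt(_prompt, confirmed, _token_counter):
        # outputs=[prompt, confirmed flag, token counter], in that order
        if confirmed:
            return ["", confirmed, "0/75"]
        return [_prompt, confirmed, _token_counter]  # leave everything as-is

    print(clear_prompt("a cat, 8k", True, "4/75"))   # ['', True, '0/75']
    print(clear_prompt("a cat, 8k", False, "4/75"))  # ['a cat, 8k', False, '4/75']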
From 9e40520f00d836cfa93187f7f1e81e2a7bd100b9 Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Fri, 21 Oct 2022 15:13:12 -0500
Subject: refactor internal terminology to use 'clear' instead of 'trash' like
#2728
---
modules/shared.py | 2 +-
modules/ui.py | 22 +++++++++++-----------
2 files changed, 12 insertions(+), 12 deletions(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 1585d532..ab5a0e9a 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -317,7 +317,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
- "trash_prompt_visible": OptionInfo(True, "Show trash prompt button"),
+ "clear_prompt_visible": OptionInfo(True, "Show clear prompt button"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
diff --git a/modules/ui.py b/modules/ui.py
index d3a89bf7..31150800 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -88,7 +88,7 @@ folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
-trash_prompt_symbol = '\U0001F5D1' #
+clear_prompt_symbol = '\U0001F5D1' # 🗑️
def plaintext_to_html(text):
@@ -430,14 +430,14 @@ def create_seed_inputs():
-def connect_trash_prompt(_prompt, confirmed, _token_counter):
+def clear_prompt(_prompt, confirmed, _token_counter):
if(confirmed):
return ["", confirmed, update_token_counter("", 1)]
-def trash_prompt_click(button, prompt, _dummy_confirmed, token_counter):
+def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter):
button.click(
- _js="trash_prompt",
- fn=connect_trash_prompt,
+ _js="clear_prompt",
+ fn=clear_prompt,
inputs=[prompt, _dummy_confirmed, token_counter],
outputs=[prompt, _dummy_confirmed, token_counter],
)
@@ -518,7 +518,7 @@ def create_toprow(is_img2img):
paste = gr.Button(value=paste_symbol, elem_id="paste")
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
- trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt", visible=opts.trash_prompt_visible)
+ clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id="clear_prompt", visible=opts.clear_prompt_visible)
token_counter = gr.HTML(value=" ", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
@@ -559,7 +559,7 @@ def create_toprow(is_img2img):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
prompt_style2.save_to_config = True
- return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, trash_prompt
+ return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, clear_prompt_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@@ -640,7 +640,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\
txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\
- token_button, trash_prompt_button = create_toprow(is_img2img=False)
+ token_button, clear_prompt_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
@@ -716,7 +716,7 @@ def create_ui(wrap_gradio_gpu_call):
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
- trash_prompt_click(trash_prompt_button, txt2img_prompt, dummy_component, token_counter)
+ connect_clear_prompt(clear_prompt_button, txt2img_prompt, dummy_component, token_counter)
txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
@@ -853,7 +853,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit,\
img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\
- token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True)
+ token_counter, token_button, clear_prompt_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'):
@@ -954,7 +954,7 @@ def create_ui(wrap_gradio_gpu_call):
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
- trash_prompt_click(trash_prompt_button, img2img_prompt, dummy_component, token_counter)
+ connect_clear_prompt(clear_prompt_button, img2img_prompt, dummy_component, token_counter)
img2img_prompt_img.change(
fn=modules.images.image_data,
--
cgit v1.2.3
From 0c7cf08b3d27a61bab4cd8b16f8be8ae74879423 Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Fri, 21 Oct 2022 15:32:26 -0500
Subject: some doc and formatting
---
modules/ui.py | 17 ++++++++++++-----
1 file changed, 12 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 31150800..b26cf10a 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -88,7 +88,7 @@ folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
-clear_prompt_symbol = '\U0001F5D1' # 🗑️
+clear_prompt_symbol = '\U0001F5D1' # 🗑️
def plaintext_to_html(text):
@@ -429,12 +429,14 @@ def create_seed_inputs():
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
-
def clear_prompt(_prompt, confirmed, _token_counter):
- if(confirmed):
- return ["", confirmed, update_token_counter("", 1)]
+ """Given confirmation from a user on the client-side, go ahead with clearing prompt"""
+ if confirmed:
+ return ["", confirmed, update_token_counter("", 1)]
+
def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter):
+ """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
button.click(
_js="clear_prompt",
fn=clear_prompt,
@@ -518,7 +520,12 @@ def create_toprow(is_img2img):
paste = gr.Button(value=paste_symbol, elem_id="paste")
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
- clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id="clear_prompt", visible=opts.clear_prompt_visible)
+
+ clear_prompt_button = gr.Button(
+ value=clear_prompt_symbol,
+ elem_id="clear_prompt",
+ visible=opts.clear_prompt_visible
+ )
token_counter = gr.HTML(value=" ", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
--
cgit v1.2.3
From 57eb54b838faa383c10079e1bb5471b7bee6a695 Mon Sep 17 00:00:00 2001
From: Extraltodeus
Date: Sat, 22 Oct 2022 00:11:07 +0200
Subject: implement CUDA device selection by ID
---
modules/devices.py | 21 ++++++++++++++++++---
1 file changed, 18 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/devices.py b/modules/devices.py
index eb422583..8a159282 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,7 +1,6 @@
+import sys, os, shlex
import contextlib
-
import torch
-
from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
@@ -9,10 +8,26 @@ has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
+def extract_device_id(args, name):
+ for x in range(len(args)):
+ if name in args[x]: return args[x+1]
+ return None
def get_optimal_device():
if torch.cuda.is_available():
- return torch.device("cuda")
+ # CUDA device selection support:
+ if "shared" not in sys.modules:
+ commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop.
+ sys.argv += shlex.split(commandline_args)
+ device_id = extract_device_id(sys.argv, '--device-id')
+ else:
+ device_id = shared.cmd_opts.device_id
+
+ if device_id is not None:
+ cuda_device = f"cuda:{device_id}"
+ return torch.device(cuda_device)
+ else:
+ return torch.device("cuda")
if has_mps:
return torch.device("mps")
--
cgit v1.2.3
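
Editor's note: a condensed, hedged sketch of the device-selection logic. Unlike the patch's extract_device_id, this guards against the flag being the last argument (which would raise IndexError there); also note the diff's else-branch references shared.cmd_opts without importing shared in devices.py. The command line below is an example.

    import shlex
    import torch

    def extract_device_id(args, name):
        # return the value following the flag, e.g. "--device-id 1" -> "1"
        for i, arg in enumerate(args):
            if name in arg and i + 1 < len(args):
                return args[i + 1]
        return None

    args = shlex.split("webui.py --device-id 1 --api")
    device_id = extract_device_id(args, "--device-id")
    use_cuda = torch.cuda.is_available() and device_id is not None
    print(torch.device(f"cuda:{device_id}" if use_cuda else "cpu"))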
From 29bfacd63cb5c73b9643d94f255cca818fd49d9c Mon Sep 17 00:00:00 2001
From: Extraltodeus
Date: Sat, 22 Oct 2022 00:12:46 +0200
Subject: implement CUDA device selection, --device-id arg
---
modules/shared.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 41d7f08e..03032a47 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -80,6 +80,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
+parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
cmd_opts = parser.parse_args()
restricted_opts = [
--
cgit v1.2.3
From 700340448baa7412c7cc5ff3d1349ac79ee8ed0c Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Fri, 21 Oct 2022 17:24:04 -0500
Subject: forgot to clear neg prompt after moving clearing to the backend; add tooltip to hints
---
modules/ui.py | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index b26cf10a..25aeba3b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -429,19 +429,19 @@ def create_seed_inputs():
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
-def clear_prompt(_prompt, confirmed, _token_counter):
+def clear_prompt(_prompt, _prompt_neg, confirmed, _token_counter):
"""Given confirmation from a user on the client-side, go ahead with clearing prompt"""
if confirmed:
- return ["", confirmed, update_token_counter("", 1)]
+ return ["", "", confirmed, update_token_counter("", 1)]
-def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter):
+def connect_clear_prompt(button, prompt, prompt_neg, _dummy_confirmed, token_counter):
"""Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
button.click(
_js="clear_prompt",
fn=clear_prompt,
- inputs=[prompt, _dummy_confirmed, token_counter],
- outputs=[prompt, _dummy_confirmed, token_counter],
+ inputs=[prompt, prompt_neg, _dummy_confirmed, token_counter],
+ outputs=[prompt, prompt_neg, _dummy_confirmed, token_counter],
)
@@ -723,7 +723,7 @@ def create_ui(wrap_gradio_gpu_call):
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
- connect_clear_prompt(clear_prompt_button, txt2img_prompt, dummy_component, token_counter)
+ connect_clear_prompt(clear_prompt_button, txt2img_prompt, txt2img_negative_prompt, dummy_component, token_counter)
txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
@@ -961,7 +961,7 @@ def create_ui(wrap_gradio_gpu_call):
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
- connect_clear_prompt(clear_prompt_button, img2img_prompt, dummy_component, token_counter)
+ connect_clear_prompt(clear_prompt_button, img2img_prompt, img2img_negative_prompt, dummy_component, token_counter)
img2img_prompt_img.change(
fn=modules.images.image_data,
--
cgit v1.2.3
From 40ddb6df61564684263c7442bacf61efe3882b87 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Sat, 22 Oct 2022 10:16:22 +0800
Subject: inspiration perfected
---
modules/inspiration.py | 71 +++++++++++++++++++++++++++-----------------------
1 file changed, 39 insertions(+), 32 deletions(-)
(limited to 'modules')
diff --git a/modules/inspiration.py b/modules/inspiration.py
index f72ebf3a..319183ab 100644
--- a/modules/inspiration.py
+++ b/modules/inspiration.py
@@ -13,7 +13,7 @@ def read_name_list(file, types=None, keyword=None):
line = line.rstrip("\n")
if types is not None:
dirname = os.path.split(line)
- if dirname[0] in types and keyword in dirname[1]:
+ if dirname[0] in types and keyword in dirname[1].lower():
ret.append(line)
else:
ret.append(line)
@@ -21,8 +21,10 @@ def read_name_list(file, types=None, keyword=None):
return ret
def save_name_list(file, name):
- with open(file, "a") as f:
- f.write(name + "\n")
+ name_list = read_name_list(file)
+ if name not in name_list:
+ with open(file, "a") as f:
+ f.write(name + "\n")
def get_types_list():
files = os.listdir(opts.inspiration_dir)
@@ -39,20 +41,20 @@ def get_types_list():
return types
def get_inspiration_images(source, types, keyword):
+ keyword = keyword.strip(" ").lower()
get_num = int(opts.inspiration_rows_num * opts.inspiration_cols_num)
if source == "Favorites":
names = read_name_list(os.path.join(inspiration_system_path, "faverites.txt"), types, keyword)
names = random.sample(names, get_num) if len(names) > get_num else names
elif source == "Abandoned":
names = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword)
- print(names)
names = random.sample(names, get_num) if len(names) > get_num else names
elif source == "Exclude abandoned":
abandoned = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword)
all_names = []
for tp in types:
name_list = os.listdir(os.path.join(opts.inspiration_dir, tp))
- all_names += [os.path.join(tp, x) for x in name_list if keyword in x]
+ all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()]
if len(all_names) > get_num:
names = []
@@ -66,14 +68,14 @@ def get_inspiration_images(source, types, keyword):
all_names = []
for tp in types:
name_list = os.listdir(os.path.join(opts.inspiration_dir, tp))
- all_names += [os.path.join(tp, x) for x in name_list if keyword in x]
+ all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()]
names = random.sample(all_names, get_num) if len(all_names) > get_num else all_names
image_list = []
for a in names:
image_path = os.path.join(opts.inspiration_dir, a)
images = os.listdir(image_path)
image_list.append((os.path.join(image_path, random.choice(images)), a))
- return image_list, names, ""
+ return image_list, names
def select_click(index, name_list):
name = name_list[int(index)]
@@ -83,22 +85,18 @@ def select_click(index, name_list):
def give_up_click(name):
file = os.path.join(inspiration_system_path, "abandoned.txt")
- name_list = read_name_list(file)
- if name not in name_list:
- save_name_list(file, name)
+ save_name_list(file, name)
return "Added to abandoned list"
def collect_click(name):
file = os.path.join(inspiration_system_path, "faverites.txt")
- name_list = read_name_list(file)
- if name not in name_list:
- save_name_list(file, name)
+ save_name_list(file, name)
return "Added to faverite list"
def moveout_click(name, source):
if source == "Abandoned":
file = os.path.join(inspiration_system_path, "abandoned.txt")
- if source == "Favorites":
+ elif source == "Favorites":
file = os.path.join(inspiration_system_path, "faverites.txt")
else:
return None
@@ -107,8 +105,8 @@ def moveout_click(name, source):
with open(file, "a") as f:
for a in name_list:
if a != name:
- f.write(a)
- return "Moved out {name} from {source} list"
+ f.write(a + "\n")
+ return f"Moved out {name} from {source} list"
def source_change(source):
if source in ["Abandoned", "Favorites"]:
@@ -116,10 +114,12 @@ def source_change(source):
else:
return gradio.update(visible=False), []
def add_to_prompt(name, prompt):
- print(name, prompt)
name = os.path.basename(name)
return prompt + "," + name
+def clear_keyword():
+ return ""
+
def ui(gr, opts, txt2img_prompt, img2img_prompt):
with gr.Blocks(analytics_enabled=False) as inspiration:
flag = os.path.exists(opts.inspiration_dir)
@@ -132,15 +132,15 @@ def ui(gr, opts, txt2img_prompt, img2img_prompt):
gr.HTML("""
- """)
+ """)
return inspiration
if not os.path.exists(inspiration_system_path):
os.mkdir(inspiration_system_path)
@@ -148,35 +148,42 @@ def ui(gr, opts, txt2img_prompt, img2img_prompt):
with gr.Column(scale=2):
inspiration_gallery = gr.Gallery(show_label=False, elem_id="inspiration_gallery").style(grid=opts.inspiration_cols_num, height='auto')
with gr.Column(scale=1):
- print(types)
types = gr.CheckboxGroup(choices=types, value=types)
- keyword = gr.Textbox("", label="Key word")
- with gr.Row():
- source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source")
- get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button")
- name = gr.Textbox(show_label=False, interactive=False)
with gr.Row():
+ source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source")
+ keyword = gr.Textbox("", label="Key word")
+ get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button")
+ name = gr.Textbox(show_label=False, interactive=False)
+ with gr.Row():
send_to_txt2img = gr.Button('to txt2img')
send_to_img2img = gr.Button('to img2img')
style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto')
- collect = gr.Button('Collect')
- give_up = gr.Button("Don't show again")
- moveout = gr.Button("Move out", visible=False)
warning = gr.HTML()
+ with gr.Row():
+ collect = gr.Button('Collect')
+ give_up = gr.Button("Don't show again")
+ moveout = gr.Button("Move out", visible=False)
+
with gr.Row(visible=False):
select_button = gr.Button('set button', elem_id="inspiration_select_button")
name_list = gr.State()
- get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list, keyword])
- source.change(source_change, inputs=[source], outputs=[moveout, style_gallery])
- source.change(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None)
+ get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list])
keyword.submit(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None)
+ source.change(source_change, inputs=[source], outputs=[moveout, style_gallery])
+ source.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword])
+ types.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword])
+
select_button.click(select_click, _js="inspiration_selected", inputs=[name, name_list], outputs=[name, style_gallery, warning])
give_up.click(give_up_click, inputs=[name], outputs=[warning])
collect.click(collect_click, inputs=[name], outputs=[warning])
moveout.click(moveout_click, inputs=[name, source], outputs=[warning])
+ moveout.click(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None)
+
send_to_txt2img.click(add_to_prompt, inputs=[name, txt2img_prompt], outputs=[txt2img_prompt])
send_to_img2img.click(add_to_prompt, inputs=[name, img2img_prompt], outputs=[img2img_prompt])
+ send_to_txt2img.click(collect_click, inputs=[name], outputs=[warning])
+ send_to_img2img.click(collect_click, inputs=[name], outputs=[warning])
send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None)
send_to_img2img.click(None, _js="switch_to_img2img_img2img", inputs=None, outputs=None)
return inspiration
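For reference, the deduplication moved into save_name_list can be exercised in isolation. Below is a minimal sketch of the intended round-trip, simplified to drop the types/keyword filtering that the module's real read_name_list performs (file and entry names are hypothetical):

    import os

    def read_name_list(file):
        # simplified: the real function also filters by types and keyword
        if not os.path.exists(file):
            return []
        with open(file, "r") as f:
            return [line.strip() for line in f if line.strip()]

    def save_name_list(file, name):
        # append only if the name is not already recorded, as in the patch above
        if name not in read_name_list(file):
            with open(file, "a") as f:
                f.write(name + "\n")

    save_name_list("faverites.txt", "styleA/img_001")  # appends
    save_name_list("faverites.txt", "styleA/img_001")  # no-op on the second call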
--
cgit v1.2.3
From d93ea5cdeb2fd3607b7265271ccab2c9bf4c1156 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Sat, 22 Oct 2022 10:21:21 +0800
Subject: inspiration perfected
---
modules/inspiration.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/inspiration.py b/modules/inspiration.py
index 319183ab..94ff139a 100644
--- a/modules/inspiration.py
+++ b/modules/inspiration.py
@@ -73,8 +73,11 @@ def get_inspiration_images(source, types, keyword):
image_list = []
for a in names:
image_path = os.path.join(opts.inspiration_dir, a)
- images = os.listdir(image_path)
- image_list.append((os.path.join(image_path, random.choice(images)), a))
+ images = os.listdir(image_path)
+ if len(images) > 0:
+ image_list.append((os.path.join(image_path, random.choice(images)), a))
+ else:
+ print(image_path)
return image_list, names
def select_click(index, name_list):
--
cgit v1.2.3
From 67b78f0ea6f196bfdca49932da062631bb40d0b1 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Sat, 22 Oct 2022 10:29:23 +0800
Subject: inspiration perfected
---
modules/inspiration.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/inspiration.py b/modules/inspiration.py
index 94ff139a..29cf8297 100644
--- a/modules/inspiration.py
+++ b/modules/inspiration.py
@@ -160,12 +160,13 @@ def ui(gr, opts, txt2img_prompt, img2img_prompt):
with gr.Row():
send_to_txt2img = gr.Button('to txt2img')
send_to_img2img = gr.Button('to img2img')
- style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto')
- warning = gr.HTML()
- with gr.Row():
collect = gr.Button('Collect')
give_up = gr.Button("Don't show again")
- moveout = gr.Button("Move out", visible=False)
+ moveout = gr.Button("Move out", visible=False)
+ warning = gr.HTML()
+ style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto')
+
+
with gr.Row(visible=False):
select_button = gr.Button('set button', elem_id="inspiration_select_button")
--
cgit v1.2.3
From 2b91251637078e04472c91a06a8d9c4db9c1dcf0 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 12:23:45 +0300
Subject: removed aesthetic gradients as built-in added support for extensions
---
modules/aesthetic_clip.py | 241 --------------------------------------------
modules/images_history.py | 2 +-
modules/img2img.py | 5 +-
modules/processing.py | 35 ++++---
modules/script_callbacks.py | 42 ++++++++
modules/scripts.py | 210 ++++++++++++++++++++++++++++----------
modules/sd_hijack.py | 1 -
modules/sd_models.py | 7 +-
modules/shared.py | 19 ----
modules/txt2img.py | 5 +-
modules/ui.py | 83 +++------------
11 files changed, 244 insertions(+), 406 deletions(-)
delete mode 100644 modules/aesthetic_clip.py
create mode 100644 modules/script_callbacks.py
(limited to 'modules')
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
deleted file mode 100644
index 8c828541..00000000
--- a/modules/aesthetic_clip.py
+++ /dev/null
@@ -1,241 +0,0 @@
-import copy
-import itertools
-import os
-from pathlib import Path
-import html
-import gc
-
-import gradio as gr
-import torch
-from PIL import Image
-from torch import optim
-
-from modules import shared
-from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
-from tqdm.auto import tqdm, trange
-from modules.shared import opts, device
-
-
-def get_all_images_in_folder(folder):
- return [os.path.join(folder, f) for f in os.listdir(folder) if
- os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
-
-
-def check_is_valid_image_file(filename):
- return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp"))
-
-
-def batched(dataset, total, n=1):
- for ndx in range(0, total, n):
- yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
-
-
-def iter_to_batched(iterable, n=1):
- it = iter(iterable)
- while True:
- chunk = tuple(itertools.islice(it, n))
- if not chunk:
- return
- yield chunk
-
-
-def create_ui():
- import modules.ui
-
- with gr.Group():
- with gr.Accordion("Open for Clip Aesthetic!", open=False):
- with gr.Row():
- aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight",
- value=0.9)
- aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
-
- with gr.Row():
- aesthetic_lr = gr.Textbox(label='Aesthetic learning rate',
- placeholder="Aesthetic learning rate", value="0.0001")
- aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
- aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
- label="Aesthetic imgs embedding",
- value="None")
-
- modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")
-
- with gr.Row():
- aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
- placeholder="This text is used to rotate the feature space of the imgs embs",
- value="")
- aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01,
- value=0.1)
- aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
-
- return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
-
-
-aesthetic_clip_model = None
-
-
-def aesthetic_clip():
- global aesthetic_clip_model
-
- if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
- aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
- aesthetic_clip_model.cpu()
-
- return aesthetic_clip_model
-
-
-def generate_imgs_embd(name, folder, batch_size):
- model = aesthetic_clip().to(device)
- processor = CLIPProcessor.from_pretrained(model.name_or_path)
-
- with torch.no_grad():
- embs = []
- for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
- desc=f"Generating embeddings for {name}"):
- if shared.state.interrupted:
- break
- inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
- outputs = model.get_image_features(**inputs).cpu()
- embs.append(torch.clone(outputs))
- inputs.to("cpu")
- del inputs, outputs
-
- embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
-
- # The generated embedding will be located here
- path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
- torch.save(embs, path)
-
- model.cpu()
- del processor
- del embs
- gc.collect()
- torch.cuda.empty_cache()
- res = f"""
- Done generating embedding for {name}!
- Aesthetic embedding saved to {html.escape(path)}
- """
- shared.update_aesthetic_embeddings()
- return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
- value="None"), \
- gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
- label="Imgs embedding",
- value="None"), res, ""
-
-
-def slerp(low, high, val):
- low_norm = low / torch.norm(low, dim=1, keepdim=True)
- high_norm = high / torch.norm(high, dim=1, keepdim=True)
- omega = torch.acos((low_norm * high_norm).sum(1))
- so = torch.sin(omega)
- res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
- return res
-
-
-class AestheticCLIP:
- def __init__(self):
- self.skip = False
- self.aesthetic_steps = 0
- self.aesthetic_weight = 0
- self.aesthetic_lr = 0
- self.slerp = False
- self.aesthetic_text_negative = ""
- self.aesthetic_slerp_angle = 0
- self.aesthetic_imgs_text = ""
-
- self.image_embs_name = None
- self.image_embs = None
- self.load_image_embs(None)
-
- def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
- aesthetic_slerp=True, aesthetic_imgs_text="",
- aesthetic_slerp_angle=0.15,
- aesthetic_text_negative=False):
- self.aesthetic_imgs_text = aesthetic_imgs_text
- self.aesthetic_slerp_angle = aesthetic_slerp_angle
- self.aesthetic_text_negative = aesthetic_text_negative
- self.slerp = aesthetic_slerp
- self.aesthetic_lr = aesthetic_lr
- self.aesthetic_weight = aesthetic_weight
- self.aesthetic_steps = aesthetic_steps
- self.load_image_embs(image_embs_name)
-
- if self.image_embs_name is not None:
- p.extra_generation_params.update({
- "Aesthetic LR": aesthetic_lr,
- "Aesthetic weight": aesthetic_weight,
- "Aesthetic steps": aesthetic_steps,
- "Aesthetic embedding": self.image_embs_name,
- "Aesthetic slerp": aesthetic_slerp,
- "Aesthetic text": aesthetic_imgs_text,
- "Aesthetic text negative": aesthetic_text_negative,
- "Aesthetic slerp angle": aesthetic_slerp_angle,
- })
-
- def set_skip(self, skip):
- self.skip = skip
-
- def load_image_embs(self, image_embs_name):
- if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
- image_embs_name = None
- self.image_embs_name = None
- if image_embs_name is not None and self.image_embs_name != image_embs_name:
- self.image_embs_name = image_embs_name
- self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
- self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
- self.image_embs.requires_grad_(False)
-
- def __call__(self, z, remade_batch_tokens):
- if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
- tokenizer = shared.sd_model.cond_stage_model.tokenizer
- if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [
- [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
- remade_batch_tokens]
-
- tokens = torch.asarray(remade_batch_tokens).to(device)
-
- model = copy.deepcopy(aesthetic_clip()).to(device)
- model.requires_grad_(True)
- if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
- text_embs_2 = model.get_text_features(
- **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
- if self.aesthetic_text_negative:
- text_embs_2 = self.image_embs - text_embs_2
- text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
- img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
- else:
- img_embs = self.image_embs
-
- with torch.enable_grad():
-
- # We optimize the model to maximize the similarity
- optimizer = optim.Adam(
- model.text_model.parameters(), lr=self.aesthetic_lr
- )
-
- for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
- text_embs = model.get_text_features(input_ids=tokens)
- text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
- sim = text_embs @ img_embs.T
- loss = -sim
- optimizer.zero_grad()
- loss.mean().backward()
- optimizer.step()
-
- zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
- if opts.CLIP_stop_at_last_layers > 1:
- zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
- zn = model.text_model.final_layer_norm(zn)
- else:
- zn = zn.last_hidden_state
- model.cpu()
- del model
- gc.collect()
- torch.cuda.empty_cache()
- zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
- if self.slerp:
- z = slerp(z, zn, self.aesthetic_weight)
- else:
- z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
-
- return z
diff --git a/modules/images_history.py b/modules/images_history.py
index 78fd0543..bc5cf11f 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -310,7 +310,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
forward = gr.Button('Prev batch')
backward = gr.Button('Next batch')
with gr.Column(scale=3):
- load_info = gr.HTML(visible=not custom_dir)
+ load_info = gr.HTML(visible=not custom_dir)
with gr.Row(visible=False) as warning:
warning_box = gr.Textbox("Message", interactive=False)
diff --git a/modules/img2img.py b/modules/img2img.py
index eea5199b..8d9f7cf9 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_inpaint = mode == 1
is_batch = mode == 2
@@ -109,7 +109,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert,
)
- shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
+    p.scripts = modules.scripts.scripts_img2img
+ p.script_args = args
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/processing.py b/modules/processing.py
index ff1ec4c9..372489f7 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -104,6 +104,12 @@ class StableDiffusionProcessing():
self.seed_resize_from_h = 0
self.seed_resize_from_w = 0
+ self.scripts = None
+ self.script_args = None
+ self.all_prompts = None
+ self.all_seeds = None
+ self.all_subseeds = None
+
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -350,32 +356,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
shared.prompt_styles.apply_styles(p)
if type(p.prompt) == list:
- all_prompts = p.prompt
+ p.all_prompts = p.prompt
else:
- all_prompts = p.batch_size * p.n_iter * [p.prompt]
+ p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
if type(seed) == list:
- all_seeds = seed
+ p.all_seeds = seed
else:
- all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
+ p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
if type(subseed) == list:
- all_subseeds = subseed
+ p.all_subseeds = subseed
else:
- all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
+ p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
+ if p.scripts is not None:
+ p.scripts.run_alwayson_scripts(p)
+
infotexts = []
output_images = []
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
- p.init(all_prompts, all_seeds, all_subseeds)
+ p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
if state.job_count == -1:
state.job_count = p.n_iter
@@ -387,9 +396,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break
- prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
- subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+ subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if (len(prompts) == 0):
break
@@ -490,10 +499,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+ return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
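Storing all_prompts, all_seeds and all_subseeds on p means always-on scripts run via run_alwayson_scripts can inspect or rewrite them before generation. As a toy illustration of the batching arithmetic that now reads from p, assuming batch_size=2, n_iter=2 and subseed_strength == 0 (values hypothetical):

    batch_size, n_iter = 2, 2
    all_prompts = ["a cat"] * (batch_size * n_iter)          # one prompt expanded per image
    all_seeds = [100 + x for x in range(len(all_prompts))]   # consecutive when subseed_strength == 0

    for n in range(n_iter):
        prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
        seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
        print(n, prompts, seeds)
    # 0 ['a cat', 'a cat'] [100, 101]
    # 1 ['a cat', 'a cat'] [102, 103]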
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
new file mode 100644
index 00000000..866b7acd
--- /dev/null
+++ b/modules/script_callbacks.py
@@ -0,0 +1,42 @@
+
+callbacks_model_loaded = []
+callbacks_ui_tabs = []
+
+
+def clear_callbacks():
+ callbacks_model_loaded.clear()
+ callbacks_ui_tabs.clear()
+
+
+def model_loaded_callback(sd_model):
+ for callback in callbacks_model_loaded:
+ callback(sd_model)
+
+
+def ui_tabs_callback():
+ res = []
+
+ for callback in callbacks_ui_tabs:
+ res += callback() or []
+
+ return res
+
+
+def on_model_loaded(callback):
+ """register a function to be called when the stable diffusion model is created; the model is
+ passed as an argument"""
+ callbacks_model_loaded.append(callback)
+
+
+def on_ui_tabs(callback):
+ """register a function to be called when the UI is creating new tabs.
+    The function must either return None, which means no new tabs are to be added, or a list, where
+    each element is a tuple:
+ (gradio_component, title, elem_id)
+
+ gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
+ title is tab text displayed to user in the UI
+ elem_id is HTML id for the tab
+ """
+ callbacks_ui_tabs.append(callback)
+
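To show how an extension consumes this new module, here is a minimal sketch of a scripts/*.py file registering both hooks; the tab name, elem_id and contents are hypothetical:

    import gradio as gr
    from modules import script_callbacks

    def add_demo_tab():
        # must return a list of (gradio_component, title, elem_id) tuples, or None
        with gr.Blocks(analytics_enabled=False) as demo:
            gr.HTML("Hello from an extension tab")
        return [(demo, "Demo", "demo_tab")]

    def report_model(sd_model):
        # called once the stable diffusion model has been created
        print("model loaded")

    script_callbacks.on_ui_tabs(add_demo_tab)
    script_callbacks.on_model_loaded(report_model)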
diff --git a/modules/scripts.py b/modules/scripts.py
index 1039fa9c..65f25f49 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,86 +1,153 @@
import os
import sys
import traceback
+from collections import namedtuple
import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared
+from modules import shared, paths, script_callbacks
+
+AlwaysVisible = object()
+
class Script:
filename = None
args_from = None
args_to = None
+ alwayson = False
+
+ infotext_fields = None
+ """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
+ parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
+ """
- # The title of the script. This is what will be displayed in the dropdown menu.
def title(self):
+ """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
+
raise NotImplementedError()
- # How the script is displayed in the UI. See https://gradio.app/docs/#components
- # for the different UI components you can use and how to create them.
- # Most UI components can return a value, such as a boolean for a checkbox.
- # The returned values are passed to the run method as parameters.
def ui(self, is_img2img):
+ """this function should create gradio UI elements. See https://gradio.app/docs/#components
+ The return value should be an array of all components that are used in processing.
+        Values of those returned components will be passed to the run() and process() functions.
+ """
+
pass
- # Determines when the script should be shown in the dropdown menu via the
- # returned value. As an example:
- # is_img2img is True if the current tab is img2img, and False if it is txt2img.
- # Thus, return is_img2img to only show the script on the img2img tab.
def show(self, is_img2img):
+ """
+        is_img2img is True if this function is called for the img2img interface, and False otherwise
+
+ This function should return:
+ - False if the script should not be shown in UI at all
+         - True if the script should be shown in UI if it's selected in the scripts dropdown
+ - script.AlwaysVisible if the script should be shown in UI at all times
+ """
+
return True
- # This is where the additional processing is implemented. The parameters include
- # self, the model object "p" (a StableDiffusionProcessing class, see
- # processing.py), and the parameters returned by the ui method.
- # Custom functions can be defined here, and additional libraries can be imported
- # to be used in processing. The return value should be a Processed object, which is
- # what is returned by the process_images method.
- def run(self, *args):
+ def run(self, p, *args):
+ """
+ This function is called if the script has been selected in the script dropdown.
+        It must do all processing and return the Processed object with results, the same
+        as the one returned by processing.process_images.
+
+ Usually the processing is done by calling the processing.process_images function.
+
+ args contains all values returned by components from ui()
+ """
+
raise NotImplementedError()
- # The description method is currently unused.
- # To add a description that appears when hovering over the title, amend the "titles"
- # dict in script.js to include the script title (returned by title) as a key, and
- # your description as the value.
+ def process(self, p, *args):
+ """
+        This function is called before processing begins for AlwaysVisible scripts.
+        You can modify the processing object (p) here, inject hooks, etc.
+ """
+
+ pass
+
def describe(self):
+ """unused"""
return ""
+current_basedir = paths.script_path
+
+
+def basedir():
+ """returns the base directory for the current script. For scripts in the main scripts directory,
+ this is the main directory (where webui.py resides), and for scripts in extensions directory
+ (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
+ """
+ return current_basedir
+
+
scripts_data = []
+ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
+ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
+
+
+def list_scripts(scriptdirname, extension):
+ scripts_list = []
+
+ basedir = os.path.join(paths.script_path, scriptdirname)
+ if os.path.exists(basedir):
+ for filename in sorted(os.listdir(basedir)):
+ scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
+
+ extdir = os.path.join(paths.script_path, "extensions")
+ if os.path.exists(extdir):
+ for dirname in sorted(os.listdir(extdir)):
+ dirpath = os.path.join(extdir, dirname)
+ if not os.path.isdir(dirpath):
+ continue
+ for filename in sorted(os.listdir(os.path.join(dirpath, scriptdirname))):
+ scripts_list.append(ScriptFile(dirpath, filename, os.path.join(dirpath, scriptdirname, filename)))
-def load_scripts(basedir):
- if not os.path.exists(basedir):
- return
+ scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
- for filename in sorted(os.listdir(basedir)):
- path = os.path.join(basedir, filename)
+ return scripts_list
- if os.path.splitext(path)[1].lower() != '.py':
- continue
- if not os.path.isfile(path):
- continue
+def load_scripts():
+ global current_basedir
+ scripts_data.clear()
+ script_callbacks.clear_callbacks()
+
+ scripts_list = list_scripts("scripts", ".py")
+
+ syspath = sys.path
+ for scriptfile in sorted(scripts_list):
try:
- with open(path, "r", encoding="utf8") as file:
+ if scriptfile.basedir != paths.script_path:
+ sys.path = [scriptfile.basedir] + sys.path
+ current_basedir = scriptfile.basedir
+
+ with open(scriptfile.path, "r", encoding="utf8") as file:
text = file.read()
from types import ModuleType
- compiled = compile(text, path, 'exec')
- module = ModuleType(filename)
+ compiled = compile(text, scriptfile.path, 'exec')
+ module = ModuleType(scriptfile.filename)
exec(compiled, module.__dict__)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
- scripts_data.append((script_class, path))
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
except Exception:
- print(f"Error loading script: {filename}", file=sys.stderr)
+ print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ finally:
+ sys.path = syspath
+ current_basedir = paths.script_path
+
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
@@ -96,56 +163,80 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
class ScriptRunner:
def __init__(self):
self.scripts = []
+ self.selectable_scripts = []
+ self.alwayson_scripts = []
self.titles = []
+ self.infotext_fields = []
def setup_ui(self, is_img2img):
- for script_class, path in scripts_data:
+ for script_class, path, basedir in scripts_data:
script = script_class()
script.filename = path
- if not script.show(is_img2img):
- continue
+ visibility = script.show(is_img2img)
- self.scripts.append(script)
+ if visibility == AlwaysVisible:
+ self.scripts.append(script)
+ self.alwayson_scripts.append(script)
+ script.alwayson = True
- self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
+ elif visibility:
+ self.scripts.append(script)
+ self.selectable_scripts.append(script)
- dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
- dropdown.save_to_config = True
- inputs = [dropdown]
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
+
+ inputs = [None]
+ inputs_alwayson = [True]
- for script in self.scripts:
+ def create_script_ui(script, inputs, inputs_alwayson):
script.args_from = len(inputs)
script.args_to = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
if controls is None:
- continue
+ return
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
- control.visible = False
+ if not script.alwayson:
+ control.visible = False
+
+ if script.infotext_fields is not None:
+ self.infotext_fields += script.infotext_fields
inputs += controls
+ inputs_alwayson += [script.alwayson for _ in controls]
script.args_to = len(inputs)
+ for script in self.alwayson_scripts:
+ with gr.Group():
+ create_script_ui(script, inputs, inputs_alwayson)
+
+ dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
+ dropdown.save_to_config = True
+ inputs[0] = dropdown
+
+ for script in self.selectable_scripts:
+ create_script_ui(script, inputs, inputs_alwayson)
+
def select_script(script_index):
- if 0 < script_index <= len(self.scripts):
- script = self.scripts[script_index-1]
+ if 0 < script_index <= len(self.selectable_scripts):
+ script = self.selectable_scripts[script_index-1]
args_from = script.args_from
args_to = script.args_to
else:
args_from = 0
args_to = 0
- return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
+ return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
def init_field(title):
if title == 'None':
return
script_index = self.titles.index(title)
- script = self.scripts[script_index]
+ script = self.selectable_scripts[script_index]
for i in range(script.args_from, script.args_to):
inputs[i].visible = True
@@ -164,7 +255,7 @@ class ScriptRunner:
if script_index == 0:
return None
- script = self.scripts[script_index-1]
+ script = self.selectable_scripts[script_index-1]
if script is None:
return None
@@ -176,6 +267,15 @@ class ScriptRunner:
return processed
+ def run_alwayson_scripts(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.process(p, *script_args)
+ except Exception:
+ print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def reload_sources(self):
for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
@@ -197,19 +297,21 @@ class ScriptRunner:
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
+
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
def reload_script_body_only():
scripts_txt2img.reload_sources()
scripts_img2img.reload_sources()
-def reload_scripts(basedir):
+def reload_scripts():
global scripts_txt2img, scripts_img2img
- scripts_data.clear()
- load_scripts(basedir)
+ load_scripts()
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
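Putting the new Script API together, a minimal always-on script might look like the following sketch (class and label names hypothetical). show() returns scripts.AlwaysVisible so the UI is rendered in its own group rather than behind the script dropdown, and process() receives the checkbox value before generation starts:

    import gradio as gr
    from modules import scripts

    class ExampleScript(scripts.Script):
        def title(self):
            return "Example always-on script"

        def show(self, is_img2img):
            # always rendered; process() will be called for every generation
            return scripts.AlwaysVisible

        def ui(self, is_img2img):
            enabled = gr.Checkbox(label="Enabled", value=False)
            # values of returned components arrive in process() as *args
            return [enabled]

        def process(self, p, enabled):
            # runs before processing begins; free to mutate the processing object
            if enabled:
                p.extra_generation_params["Example"] = "on"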
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 1f8587d1..0f10828e 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -332,7 +332,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
- z1 = shared.aesthetic_clip(z1, remade_batch_tokens)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
diff --git a/modules/sd_models.py b/modules/sd_models.py
index d99dbce8..f9b3063d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -7,7 +7,7 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices
+from modules import shared, modelloader, devices, script_callbacks
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
@@ -238,6 +238,9 @@ def load_model(checkpoint_info=None):
sd_hijack.model_hijack.hijack(sd_model)
sd_model.eval()
+ shared.sd_model = sd_model
+
+ script_callbacks.model_loaded_callback(sd_model)
print(f"Model loaded.")
return sd_model
@@ -252,7 +255,7 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
- shared.sd_model = load_model(checkpoint_info)
+ load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
diff --git a/modules/shared.py b/modules/shared.py
index 0dbe360d..7d786f07 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -31,7 +31,6 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
-parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(models_path, 'aesthetic_embeddings'), help="aesthetic_embeddings directory(default: aesthetic_embeddings)")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
@@ -109,21 +108,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
-
-os.makedirs(cmd_opts.aesthetic_embeddings_dir, exist_ok=True)
-aesthetic_embeddings = {}
-
-
-def update_aesthetic_embeddings():
- global aesthetic_embeddings
- aesthetic_embeddings = {f.replace(".pt", ""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
- os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
- aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings)
-
-
-update_aesthetic_embeddings()
-
-
def reload_hypernetworks():
global hypernetworks
@@ -415,9 +399,6 @@ sd_model = None
clip_model = None
-from modules.aesthetic_clip import AestheticCLIP
-aesthetic_clip = AestheticCLIP()
-
progress_print_out = sys.stdout
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 1761cfa2..c9d5a090 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -7,7 +7,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -36,7 +36,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
firstphase_height=firstphase_height if enable_hr else None,
)
- shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
+ p.scripts = modules.scripts.scripts_txt2img
+ p.script_args = args
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/ui.py b/modules/ui.py
index 70a9cf10..c977482c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -23,10 +23,10 @@ import gradio as gr
import gradio.utils
import gradio.routes
-from modules import sd_hijack, sd_models, localization
+from modules import sd_hijack, sd_models, localization, script_callbacks
from modules.paths import script_path
-from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings
+from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
@@ -44,7 +44,6 @@ from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui
-import modules.aesthetic_clip as aesthetic_clip
import modules.images_history as img_his
@@ -662,8 +661,6 @@ def create_ui(wrap_gradio_gpu_call):
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
- aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui()
-
with gr.Group():
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
@@ -718,14 +715,6 @@ def create_ui(wrap_gradio_gpu_call):
denoising_strength,
firstphase_width,
firstphase_height,
- aesthetic_lr,
- aesthetic_weight,
- aesthetic_steps,
- aesthetic_imgs,
- aesthetic_slerp,
- aesthetic_imgs_text,
- aesthetic_slerp_angle,
- aesthetic_text_negative
] + custom_inputs,
outputs=[
@@ -804,14 +793,7 @@ def create_ui(wrap_gradio_gpu_call):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
- (aesthetic_lr, "Aesthetic LR"),
- (aesthetic_weight, "Aesthetic weight"),
- (aesthetic_steps, "Aesthetic steps"),
- (aesthetic_imgs, "Aesthetic embedding"),
- (aesthetic_slerp, "Aesthetic slerp"),
- (aesthetic_imgs_text, "Aesthetic text"),
- (aesthetic_text_negative, "Aesthetic text negative"),
- (aesthetic_slerp_angle, "Aesthetic slerp angle"),
+ *modules.scripts.scripts_txt2img.infotext_fields
]
txt2img_preview_params = [
@@ -896,8 +878,6 @@ def create_ui(wrap_gradio_gpu_call):
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
- aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui()
-
with gr.Group():
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
@@ -988,14 +968,6 @@ def create_ui(wrap_gradio_gpu_call):
inpainting_mask_invert,
img2img_batch_input_dir,
img2img_batch_output_dir,
- aesthetic_lr_im,
- aesthetic_weight_im,
- aesthetic_steps_im,
- aesthetic_imgs_im,
- aesthetic_slerp_im,
- aesthetic_imgs_text_im,
- aesthetic_slerp_angle_im,
- aesthetic_text_negative_im,
] + custom_inputs,
outputs=[
img2img_gallery,
@@ -1087,14 +1059,7 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
- (aesthetic_lr_im, "Aesthetic LR"),
- (aesthetic_weight_im, "Aesthetic weight"),
- (aesthetic_steps_im, "Aesthetic steps"),
- (aesthetic_imgs_im, "Aesthetic embedding"),
- (aesthetic_slerp_im, "Aesthetic slerp"),
- (aesthetic_imgs_text_im, "Aesthetic text"),
- (aesthetic_text_negative_im, "Aesthetic text negative"),
- (aesthetic_slerp_angle_im, "Aesthetic slerp angle"),
+ *modules.scripts.scripts_img2img.infotext_fields
]
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
@@ -1217,9 +1182,9 @@ def create_ui(wrap_gradio_gpu_call):
)
#images history
images_history_switch_dict = {
- "fn":modules.generation_parameters_copypaste.connect_paste,
- "t2i":txt2img_paste_fields,
- "i2i":img2img_paste_fields
+ "fn": modules.generation_parameters_copypaste.connect_paste,
+ "t2i": txt2img_paste_fields,
+ "i2i": img2img_paste_fields
}
images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
@@ -1264,18 +1229,6 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary')
- with gr.Tab(label="Create aesthetic images embedding"):
-
- new_embedding_name_ae = gr.Textbox(label="Name")
- process_src_ae = gr.Textbox(label='Source directory')
- batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
- with gr.Row():
- with gr.Column(scale=3):
- gr.HTML(value="")
-
- with gr.Column():
- create_embedding_ae = gr.Button(value="Create images embedding", variant='primary')
-
with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
@@ -1375,21 +1328,6 @@ def create_ui(wrap_gradio_gpu_call):
]
)
- create_embedding_ae.click(
- fn=aesthetic_clip.generate_imgs_embd,
- inputs=[
- new_embedding_name_ae,
- process_src_ae,
- batch_ae
- ],
- outputs=[
- aesthetic_imgs,
- aesthetic_imgs_im,
- ti_output,
- ti_outcome,
- ]
- )
-
create_hypernetwork.click(
fn=modules.hypernetworks.ui.create_hypernetwork,
inputs=[
@@ -1580,10 +1518,10 @@ Requested path was: {f}
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
+ oldval = opts.data.get(key, None)
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
return gr.update(value=oldval), opts.dumpjson()
- oldval = opts.data.get(key, None)
opts.data[key] = value
if oldval != value:
@@ -1692,9 +1630,12 @@ Requested path was: {f}
(images_history, "Image Browser", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
- (settings_interface, "Settings", "settings"),
]
+ interfaces += script_callbacks.ui_tabs_callback()
+
+ interfaces += [(settings_interface, "Settings", "settings")]
+
with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
css = file.read()
--
cgit v1.2.3
From 6398dc9b1049f242576ca309f95a3fb1e654951c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 13:34:49 +0300
Subject: further support for extensions
---
modules/scripts.py | 44 +++++++++++++++++++++++++++++++++++---------
modules/ui.py | 19 ++++++++++---------
2 files changed, 45 insertions(+), 18 deletions(-)
(limited to 'modules')
diff --git a/modules/scripts.py b/modules/scripts.py
index 65f25f49..9323af3e 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -102,17 +102,39 @@ def list_scripts(scriptdirname, extension):
if os.path.exists(extdir):
for dirname in sorted(os.listdir(extdir)):
dirpath = os.path.join(extdir, dirname)
- if not os.path.isdir(dirpath):
+ scriptdirpath = os.path.join(dirpath, scriptdirname)
+
+ if not os.path.isdir(scriptdirpath):
continue
- for filename in sorted(os.listdir(os.path.join(dirpath, scriptdirname))):
- scripts_list.append(ScriptFile(dirpath, filename, os.path.join(dirpath, scriptdirname, filename)))
+ for filename in sorted(os.listdir(scriptdirpath)):
+ scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
return scripts_list
+def list_files_with_name(filename):
+ res = []
+
+ dirs = [paths.script_path]
+
+ extdir = os.path.join(paths.script_path, "extensions")
+ if os.path.exists(extdir):
+ dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
+
+ for dirpath in dirs:
+ if not os.path.isdir(dirpath):
+ continue
+
+ path = os.path.join(dirpath, filename)
+        if os.path.isfile(path):
+ res.append(path)
+
+ return res
+
+
def load_scripts():
global current_basedir
scripts_data.clear()
@@ -276,7 +298,7 @@ class ScriptRunner:
print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- def reload_sources(self):
+ def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from
@@ -286,9 +308,12 @@ class ScriptRunner:
from types import ModuleType
- compiled = compile(text, filename, 'exec')
- module = ModuleType(script.filename)
- exec(compiled, module.__dict__)
+ module = cache.get(filename, None)
+ if module is None:
+ compiled = compile(text, filename, 'exec')
+ module = ModuleType(script.filename)
+ exec(compiled, module.__dict__)
+ cache[filename] = module
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
@@ -303,8 +328,9 @@ scripts_img2img = ScriptRunner()
def reload_script_body_only():
- scripts_txt2img.reload_sources()
- scripts_img2img.reload_sources()
+ cache = {}
+ scripts_txt2img.reload_sources(cache)
+ scripts_img2img.reload_sources(cache)
def reload_scripts():
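The cache passed into reload_sources is shared between the txt2img and img2img runners, so each script file is compiled and executed at most once per reload. The pattern in isolation:

    from types import ModuleType

    def get_module(cache, filename, text):
        # compile/exec each source file once, then hand the same module to both runners
        module = cache.get(filename, None)
        if module is None:
            module = ModuleType(filename)
            exec(compile(text, filename, 'exec'), module.__dict__)
            cache[filename] = module
        return module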
diff --git a/modules/ui.py b/modules/ui.py
index c977482c..29986124 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1636,13 +1636,15 @@ Requested path was: {f}
interfaces += [(settings_interface, "Settings", "settings")]
- with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
- css = file.read()
+ css = ""
+
+ for cssfile in modules.scripts.list_files_with_name("style.css"):
+ with open(cssfile, "r", encoding="utf8") as file:
+ css += file.read() + "\n"
if os.path.exists(os.path.join(script_path, "user.css")):
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
- usercss = file.read()
- css += usercss
+ css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
@@ -1865,9 +1867,9 @@ def load_javascript(raw_response):
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
        javascript = f'<script>{jsfile.read()}</script>'
- jsdir = os.path.join(script_path, "javascript")
- for filename in sorted(os.listdir(jsdir)):
- with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
+ scripts_list = modules.scripts.list_scripts("javascript", ".js")
+ for basedir, filename, path in scripts_list:
+ with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n"
if cmd_opts.theme is not None:
@@ -1885,6 +1887,5 @@ def load_javascript(raw_response):
gradio.routes.templates.TemplateResponse = template_response
-reload_javascript = partial(load_javascript,
- gradio.routes.templates.TemplateResponse)
+reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
reload_javascript()
--
cgit v1.2.3
From 50b5504401e50b6c94eba41b37fe212b2f27b792 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 14:04:14 +0300
Subject: remove parsing command line from devices.py
---
modules/devices.py | 14 +++++---------
modules/lowvram.py | 9 ++++-----
2 files changed, 9 insertions(+), 14 deletions(-)
(limited to 'modules')
diff --git a/modules/devices.py b/modules/devices.py
index 8a159282..dc1f3cdd 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -15,14 +15,10 @@ def extract_device_id(args, name):
def get_optimal_device():
if torch.cuda.is_available():
- # CUDA device selection support:
- if "shared" not in sys.modules:
- commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop.
- sys.argv += shlex.split(commandline_args)
- device_id = extract_device_id(sys.argv, '--device-id')
- else:
- device_id = shared.cmd_opts.device_id
-
+ from modules import shared
+
+ device_id = shared.cmd_opts.device_id
+
if device_id is not None:
cuda_device = f"cuda:{device_id}"
return torch.device(cuda_device)
@@ -49,7 +45,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
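With the command-line re-parsing gone, these globals start out as None; the patch implies they are assigned once during startup, after shared.cmd_opts exists. A hedged sketch of that initialization (the actual call site is not part of this diff):

    # assumed startup code, e.g. in shared.py after argument parsing (not shown in this patch)
    from modules import devices

    devices.device = devices.device_interrogate = devices.device_gfpgan = \
        devices.device_bsrgan = devices.device_esrgan = devices.device_scunet = \
        devices.device_codeformer = devices.get_optimal_device()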
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 7eba1349..f327c3df 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,9 +1,8 @@
import torch
-from modules.devices import get_optimal_device
+from modules import devices
module_in_gpu = None
cpu = torch.device("cpu")
-device = gpu = get_optimal_device()
def send_everything_to_cpu():
@@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram):
if module_in_gpu is not None:
module_in_gpu.to(cpu)
- module.to(gpu)
+ module.to(devices.device)
module_in_gpu = module
# see below for register_forward_pre_hook;
@@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
- sd_model.to(device)
+ sd_model.to(devices.device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
# register hooks for those the first two models
@@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
- sd_model.model.to(device)
+ sd_model.model.to(devices.device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
--
cgit v1.2.3
From 0e8ca8e7af05be22d7d2c07a47c3c7febe0f0ab6 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Sat, 22 Oct 2022 11:07:00 +0000
Subject: add dropout
---
modules/hypernetworks/hypernetwork.py | 68 +++++++++++++++++++++--------------
modules/hypernetworks/ui.py | 10 +++---
modules/ui.py | 43 +++++++++++-----------
3 files changed, 70 insertions(+), 51 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 905cbeef..e493f366 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,47 +1,60 @@
+import csv
import datetime
import glob
import html
import os
import sys
import traceback
-import tqdm
-import csv
+import modules.textual_inversion.dataset
import torch
-
-from ldm.util import default
-from modules import devices, shared, processing, sd_models
-import torch
-from torch import einsum
+import tqdm
from einops import rearrange, repeat
-import modules.textual_inversion.dataset
+from ldm.util import default
+from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from torch import einsum
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
- activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU,
- "swish": torch.nn.Hardswish}
-
- def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
+ activation_dict = {
+ "relu": torch.nn.ReLU,
+ "leakyrelu": torch.nn.LeakyReLU,
+ "elu": torch.nn.ELU,
+ "swish": torch.nn.Hardswish,
+ }
+
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
-
+        assert activation_func in ["linear"] + list(self.activation_dict.keys()), "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
+
linears = []
for i in range(len(layer_structure) - 1):
+
+ # Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
- # if skip_first_layer because first parameters potentially contain negative values
- # if i < 1: continue
- if activation_func in HypernetworkModule.activation_dict:
- linears.append(HypernetworkModule.activation_dict[activation_func]())
+
+ # Add an activation func
+ if activation_func == "linear":
+ pass
+ elif activation_func in self.activation_dict:
+ linears.append(self.activation_dict[activation_func]())
else:
- print("Invalid key {} encountered as activation function!".format(activation_func))
- # if use_dropout:
- # linears.append(torch.nn.Dropout(p=0.3))
+ raise NotImplementedError(
+ "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
+ )
+
+ # Add dropout
+ if use_dropout:
+ linears.append(torch.nn.Dropout(p=0.3))
+
+ # Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -93,7 +106,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
self.filename = None
self.name = name
self.layers = {}
@@ -101,13 +114,14 @@ class Hypernetwork:
self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
- self.add_layer_norm = add_layer_norm
self.activation_func = activation_func
+ self.add_layer_norm = add_layer_norm
+ self.use_dropout = use_dropout
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
)
def weights(self):
@@ -129,8 +143,9 @@ class Hypernetwork:
state_dict['step'] = self.step
state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure
- state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['activation_func'] = self.activation_func
+ state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -144,8 +159,9 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
- self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.activation_func = state_dict.get('activation_func', None)
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ self.use_dropout = state_dict.get('use_dropout', False)
for size, sd in state_dict.items():
if type(size) == int:
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 1a5a27d8..5f6f17b6 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -3,14 +3,13 @@ import os
import re
import gradio as gr
-
-import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared, devices
+import modules.textual_inversion.textual_inversion
+from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
+def create_hypernetwork(name, enable_sizes, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
@@ -21,8 +20,9 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
name=name,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
- add_layer_norm=add_layer_norm,
activation_func=activation_func,
+ add_layer_norm=add_layer_norm,
+ use_dropout=use_dropout,
)
hypernet.save(fn)
diff --git a/modules/ui.py b/modules/ui.py
index 716f14b8..d4b32c05 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -5,43 +5,44 @@ import json
import math
import mimetypes
import os
+import platform
import random
+import subprocess as sp
import sys
import tempfile
import time
import traceback
-import platform
-import subprocess as sp
from functools import partial, reduce
+import gradio as gr
+import gradio.routes
+import gradio.utils
import numpy as np
+import piexif
import torch
from PIL import Image, PngImagePlugin
-import piexif
-import gradio as gr
-import gradio.utils
-import gradio.routes
-
-from modules import sd_hijack, sd_models, localization
+from modules import localization, sd_hijack, sd_models
from modules.paths import script_path
-from modules.shared import opts, cmd_opts, restricted_opts
+from modules.shared import cmd_opts, opts, restricted_opts
+
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
-import modules.shared as shared
-from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.sd_hijack import model_hijack
+
+import modules.codeformer_model
+import modules.generation_parameters_copypaste
+import modules.gfpgan_model
+import modules.hypernetworks.ui
+import modules.images_history as img_his
import modules.ldsr_model
import modules.scripts
-import modules.gfpgan_model
-import modules.codeformer_model
+import modules.shared as shared
import modules.styles
-import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
from modules import prompt_parser
from modules.images import save_image
-import modules.textual_inversion.ui
-import modules.hypernetworks.ui
-import modules.images_history as img_his
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -1223,8 +1224,9 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
with gr.Row():
with gr.Column(scale=3):
@@ -1308,8 +1310,9 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name,
new_hypernetwork_sizes,
new_hypernetwork_layer_structure,
- new_hypernetwork_add_layer_norm,
new_hypernetwork_activation_func,
+ new_hypernetwork_add_layer_norm,
+ new_hypernetwork_use_dropout
],
outputs=[
train_hypernetwork_name,
--
cgit v1.2.3
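For context on the change above: HypernetworkModule now assembles each size's network from layer_structure, optionally interleaving an activation, dropout, and layer normalization. A minimal standalone sketch of that assembly (illustration only, not the module itself, which also handles weight loading and a global multiplier):

import torch

# Sketch of the layer assembly introduced above; fixed p=0.3 dropout as in the patch.
activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU,
                   "elu": torch.nn.ELU, "swish": torch.nn.Hardswish}

def build_layers(dim, layer_structure, activation_func="linear",
                 add_layer_norm=False, use_dropout=False):
    linears = []
    for i in range(len(layer_structure) - 1):
        linears.append(torch.nn.Linear(int(dim * layer_structure[i]),
                                       int(dim * layer_structure[i + 1])))
        if activation_func in activation_dict:
            linears.append(activation_dict[activation_func]())
        if use_dropout:
            linears.append(torch.nn.Dropout(p=0.3))
        if add_layer_norm:
            linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))
    return torch.nn.Sequential(*linears)

# e.g. build_layers(768, [1, 2, 1], "relu", add_layer_norm=True, use_dropout=True)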
From 1cd3ed7def40198f46d30f74dd37d2906ebdbaa6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 14:28:56 +0300
Subject: fix for extensions without style.css
---
modules/ui.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 29986124..d8d52db1 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1639,6 +1639,9 @@ Requested path was: {f}
css = ""
for cssfile in modules.scripts.list_files_with_name("style.css"):
+ if not os.path.isfile(cssfile):
+ continue
+
with open(cssfile, "r", encoding="utf8") as file:
css += file.read() + "\n"
--
cgit v1.2.3
From 7fd90128eb6d1820045bfe2c2c1269661023a712 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 14:48:43 +0300
Subject: added a guard for hypernet training that will stop early if weights
are getting no gradients
---
modules/hypernetworks/hypernetwork.py | 11 +++++++++++
1 file changed, 11 insertions(+)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 47d91ea5..46039a49 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -310,6 +310,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+ steps_without_grad = 0
+
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
@@ -332,8 +334,17 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item()
optimizer.zero_grad()
+ weights[0].grad = None
loss.backward()
+
+ if weights[0].grad is None:
+ steps_without_grad += 1
+ else:
+ steps_without_grad = 0
+ assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
+
optimizer.step()
+
mean_loss = losses.mean()
if torch.isnan(mean_loss):
raise RuntimeError("Loss diverged.")
--
cgit v1.2.3
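The guard above relies on PyTorch semantics: setting a parameter's .grad to None before backward() means .grad is only non-None afterwards if backward() actually reached that parameter. A distilled sketch of the pattern, with a toy parameter standing in for the hypernetwork weights:

import torch

weight = torch.nn.Parameter(torch.randn(4))
steps_without_grad = 0

for step in range(100):
    loss = (weight * torch.randn(4)).sum()
    weight.grad = None          # clear before backward, as in the patch
    loss.backward()
    if weight.grad is None:     # backward never touched this weight
        steps_without_grad += 1
    else:
        steps_without_grad = 0
    assert steps_without_grad < 10, "no gradient for 10 steps in a row"
    # an optimizer.step() would go here in real training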
From fccba4729db341a299db3343e3264fecd9459a07 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Sat, 22 Oct 2022 12:02:41 +0000
Subject: add an option to avoid dying relu
---
modules/hypernetworks/hypernetwork.py | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b7a04038..3132a56c 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -32,7 +32,6 @@ class HypernetworkModule(torch.nn.Module):
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
- assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
linears = []
for i in range(len(layer_structure) - 1):
@@ -43,12 +42,13 @@ class HypernetworkModule(torch.nn.Module):
# Add an activation func
if activation_func == "linear" or activation_func is None:
pass
+ # If ReLU, Skip adding it to the first layer to avoid dying ReLU
+ elif activation_func == "relu" and i < 1:
+ pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
- raise RuntimeError(
- "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
- )
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
# Add dropout
if use_dropout:
@@ -166,8 +166,8 @@ class Hypernetwork:
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
- HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
)
self.name = state_dict.get('name', self.name)
--
cgit v1.2.3
From 7912acef725832debef58c4c7bf8ec22fb446c0b Mon Sep 17 00:00:00 2001
From: discus0434
Date: Sat, 22 Oct 2022 13:00:44 +0000
Subject: small fix
---
modules/hypernetworks/hypernetwork.py | 12 +++++-------
modules/ui.py | 1 -
2 files changed, 5 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3132a56c..7d12e0ff 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -42,22 +42,20 @@ class HypernetworkModule(torch.nn.Module):
# Add an activation func
if activation_func == "linear" or activation_func is None:
pass
- # If ReLU, Skip adding it to the first layer to avoid dying ReLU
- elif activation_func == "relu" and i < 1:
- pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
- # Add dropout
- if use_dropout:
- linears.append(torch.nn.Dropout(p=0.3))
-
# Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+ # Add dropout
+ if use_dropout:
+ p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2
+ linears.append(torch.nn.Dropout(p=p))
+
self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
diff --git a/modules/ui.py b/modules/ui.py
index cd118552..eca887ca 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1244,7 +1244,6 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row():
with gr.Column(scale=3):
--
cgit v1.2.3
From 6a4fa73a38935a18779ce1809892730fd1572bee Mon Sep 17 00:00:00 2001
From: discus0434
Date: Sat, 22 Oct 2022 13:44:39 +0000
Subject: small fix
---
modules/hypernetworks/hypernetwork.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3372aae2..3bc71ee5 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -51,10 +51,9 @@ class HypernetworkModule(torch.nn.Module):
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
- # Add dropout
- if use_dropout:
- p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2
- linears.append(torch.nn.Dropout(p=p))
+ # Add dropout expect last layer
+ if use_dropout and i < len(layer_structure) - 3:
+ linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
--
cgit v1.2.3
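One way to see exactly where `use_dropout and i < len(layer_structure) - 3` places dropout is to enumerate it. Note that for the default [1, 2, 1] structure no dropout is added at all, and the final two Linear layers are always skipped, which is slightly stricter than the "except last layer" comment suggests:

# Enumerate where the condition from the patch above inserts Dropout
# (the structures below are hypothetical examples).
for layer_structure in ([1, 2, 1], [1, 2, 2, 1], [1, 1.5, 1.5, 1.5, 1]):
    n = len(layer_structure)
    dropped = [i for i in range(n - 1) if i < n - 3]
    print(layer_structure, "-> dropout after Linear layers", dropped)

# [1, 2, 1]             -> dropout after Linear layers []
# [1, 2, 2, 1]          -> dropout after Linear layers [0]
# [1, 1.5, 1.5, 1.5, 1] -> dropout after Linear layers [0, 1]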
From d37cfffd537cd29309afbcb192c4f979995c6a34 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 19:18:56 +0300
Subject: added callback for creating new settings in extensions
---
modules/script_callbacks.py | 11 +++++++++++
modules/shared.py | 19 +++++++++++++++++--
modules/ui.py | 6 +++++-
3 files changed, 33 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 866b7acd..1270e50f 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,6 +1,7 @@
callbacks_model_loaded = []
callbacks_ui_tabs = []
+callbacks_ui_settings = []
def clear_callbacks():
@@ -22,6 +23,11 @@ def ui_tabs_callback():
return res
+def ui_settings_callback():
+ for callback in callbacks_ui_settings:
+ callback()
+
+
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
@@ -40,3 +46,8 @@ def on_ui_tabs(callback):
"""
callbacks_ui_tabs.append(callback)
+
+def on_ui_settings(callback):
+ """register a function to be called before UI settingsare populated; add your settings
+ by using shared.opts.add_option(shared.OptionInfo(...)) """
+ callbacks_ui_settings.append(callback)
diff --git a/modules/shared.py b/modules/shared.py
index 5d83971e..d9cb65ef 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -165,13 +165,13 @@ def realesrgan_models_names():
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
- self.section = None
+ self.section = section
self.refresh = refresh
@@ -327,6 +327,7 @@ options_templates.update(options_section(('images-history', "Images Browser"), {
}))
+
class Options:
data = None
data_labels = options_templates
@@ -389,6 +390,20 @@ class Options:
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d)
+ def add_option(self, key, info):
+ self.data_labels[key] = info
+
+ def reorder(self):
+ """reorder settings so that all items related to section always go together"""
+
+ section_ids = {}
+ settings_items = self.data_labels.items()
+ for k, item in settings_items:
+ if item.section not in section_ids:
+ section_ids[item.section] = len(section_ids)
+
+ self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+
opts = Options()
if os.path.exists(config_filename):
diff --git a/modules/ui.py b/modules/ui.py
index d8d52db1..2849b111 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1461,6 +1461,9 @@ def create_ui(wrap_gradio_gpu_call):
components = []
component_dict = {}
+ script_callbacks.ui_settings_callback()
+ opts.reorder()
+
def open_folder(f):
if not os.path.exists(f):
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
@@ -1564,7 +1567,8 @@ Requested path was: {f}
previous_section = item.section
- gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
+ elem_id, text = item.section
+ gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='<h1 class="gr-button-lg">{}</h1>'.format(text))
if k in quicksettings_names:
quicksettings_list.append((i, k, item))
--
cgit v1.2.3
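As the docstring above suggests, an extension registers an on_ui_settings callback and adds options through shared.opts.add_option, passing a (section_id, section_label) tuple so reorder() can group them. A hypothetical extension script (option key, labels, and values are made up for illustration):

import gradio as gr
from modules import script_callbacks, shared

def on_settings():
    section = ('my_extension', "My Extension")  # (elem id, display text)
    shared.opts.add_option(
        "my_extension_strength",
        shared.OptionInfo(0.5, "Effect strength", gr.Slider,
                          {"minimum": 0.0, "maximum": 1.0, "step": 0.05},
                          section=section),
    )

script_callbacks.on_ui_settings(on_settings)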
From dbc8ab65f6d496459a76547776b656c96ad1350d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 19:19:17 +0300
Subject: typo
---
modules/script_callbacks.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 1270e50f..5bcccd67 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -48,6 +48,6 @@ def on_ui_tabs(callback):
def on_ui_settings(callback):
- """register a function to be called before UI settingsare populated; add your settings
+ """register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
callbacks_ui_settings.append(callback)
--
cgit v1.2.3
From 72383abacdc6a101704a6f73758ce4d0bb68c9d1 Mon Sep 17 00:00:00 2001
From: Greendayle
Date: Sat, 22 Oct 2022 16:50:07 +0200
Subject: Deepdanbooru linux fix
---
modules/deepbooru.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 8914662d..3c34ab7c 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -50,7 +50,8 @@ def create_deepbooru_process(threshold, deepbooru_opts):
the tags.
"""
from modules import shared # prevents circular reference
- shared.deepbooru_process_manager = multiprocessing.Manager()
+ context = multiprocessing.get_context("spawn")
+ shared.deepbooru_process_manager = context.Manager()
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1
--
cgit v1.2.3
From e38625011cd4955da4bc67fe95d1d0f4c0c53899 Mon Sep 17 00:00:00 2001
From: Greendayle
Date: Sat, 22 Oct 2022 16:56:52 +0200
Subject: fix part2
---
modules/deepbooru.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 3c34ab7c..8bbc90a4 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -55,7 +55,7 @@ def create_deepbooru_process(threshold, deepbooru_opts):
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1
- shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
+ shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
shared.deepbooru_process.start()
--
cgit v1.2.3
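Together these two commits ensure the Manager, Queue, dict, and Process all come from the same "spawn" context; on Linux the default is "fork", which duplicates the parent's CUDA/TensorFlow state into the child and can hang. A distilled sketch of the pattern outside the webui:

import multiprocessing

def worker(queue, shared_dict):
    # child: consume items until a None sentinel arrives
    while True:
        item = queue.get()
        if item is None:
            break
        shared_dict["value"] = item * 2

if __name__ == "__main__":
    # "spawn" starts the child with a fresh interpreter, avoiding
    # fork-related deadlocks with GPU libraries
    context = multiprocessing.get_context("spawn")
    manager = context.Manager()
    queue = manager.Queue()
    shared_dict = manager.dict()
    process = context.Process(target=worker, args=(queue, shared_dict))
    process.start()
    queue.put(21)
    queue.put(None)
    process.join()
    print(shared_dict["value"])  # 42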
From 324c7c732dd9afc3d4c397c354797ae5d655b514 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 20:09:37 +0300
Subject: record First pass size as 0x0 for #3328
---
modules/processing.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 372489f7..27c669b0 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -524,6 +524,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
+ self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
+
if self.firstphase_width == 0 or self.firstphase_height == 0:
desired_pixel_count = 512 * 512
actual_pixel_count = self.width * self.height
@@ -545,7 +547,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
firstphase_width_truncated = self.firstphase_height * self.width / self.height
firstphase_height_truncated = self.firstphase_height
- self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
--
cgit v1.2.3
From 0df94d3fcf9d1fc47c4d39039352a3d5b3380c1f Mon Sep 17 00:00:00 2001
From: MrCheeze
Date: Sat, 22 Oct 2022 12:59:21 -0400
Subject: fix aesthetic gradients doing nothing after loading a different model
---
modules/sd_models.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f9b3063d..49dc3238 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -236,12 +236,11 @@ def load_model(checkpoint_info=None):
sd_model.to(shared.device)
sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
sd_model.eval()
shared.sd_model = sd_model
- script_callbacks.model_loaded_callback(sd_model)
-
print(f"Model loaded.")
return sd_model
@@ -268,6 +267,7 @@ def reload_model_weights(sd_model, info=None):
load_model_weights(sd_model, checkpoint_info)
sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
--
cgit v1.2.3
From 321bacc6a9eaf4a25f31279f288fa752be507a20 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 20:15:12 +0300
Subject: call model_loaded_callback after setting shared.sd_model in case
scripts refer to it using that
---
modules/sd_models.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 49dc3238..e697bb72 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -236,11 +236,12 @@ def load_model(checkpoint_info=None):
sd_model.to(shared.device)
sd_hijack.model_hijack.hijack(sd_model)
- script_callbacks.model_loaded_callback(sd_model)
sd_model.eval()
shared.sd_model = sd_model
+ script_callbacks.model_loaded_callback(sd_model)
+
print(f"Model loaded.")
return sd_model
--
cgit v1.2.3
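With this ordering a script's on_model_loaded callback can refer to shared.sd_model, since the global is assigned before the callback fires. A minimal (hypothetical) registration:

from modules import script_callbacks, shared

def on_loaded(sd_model):
    # shared.sd_model is already set when this runs, so both names
    # refer to the same model object
    assert shared.sd_model is sd_model
    print("model loaded:", getattr(sd_model, "sd_model_hash", "unknown"))

script_callbacks.on_model_loaded(on_loaded)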
From 24694e5983d0944b901892cb101878e6dec89a20 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sun, 23 Oct 2022 01:57:58 +0900
Subject: Update hypernetwork.py
---
modules/hypernetworks/hypernetwork.py | 55 ++++++++++++++++++++++++++++-------
1 file changed, 44 insertions(+), 11 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3bc71ee5..81132be4 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
+from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
@@ -268,6 +269,32 @@ def stack_conds(conds):
return torch.stack(conds)
+def log_statistics(loss_info:dict, key, value):
+ if key not in loss_info:
+ loss_info[key] = [value]
+ else:
+ loss_info[key].append(value)
+ if len(loss_info[key]) > 1024:
+ loss_info[key].pop(0)
+
+
+def statistics(data):
+ total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})"
+ recent_data = data[-32:]
+ recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})"
+ return total_information, recent_information
+
+
+def report_statistics(loss_info:dict):
+ keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
+ for key in keys:
+ info, recent = statistics(loss_info[key])
+ print("Loss statistics for file " + key)
+ print(info)
+ print(recent)
+
+
+
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@@ -310,7 +337,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
for weight in weights:
weight.requires_grad = True
- losses = torch.zeros((32,))
+ size = len(ds.indexes)
+ loss_dict = {}
+ losses = torch.zeros((size,))
+ previous_mean_loss = 0
+ print("Mean loss of {} elements".format(size))
last_saved_file = ""
last_saved_image = ""
@@ -329,7 +360,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
-
+ if loss_dict and i % size == 0:
+ previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
+
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
@@ -346,7 +379,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
del c
losses[hypernetwork.step % losses.shape[0]] = loss.item()
-
+ for entry in entries:
+ log_statistics(loss_dict, entry.filename, loss.item())
+
optimizer.zero_grad()
weights[0].grad = None
loss.backward()
@@ -359,10 +394,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
optimizer.step()
- mean_loss = losses.mean()
- if torch.isnan(mean_loss):
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
- pbar.set_description(f"loss: {mean_loss:.7f}")
+ pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
@@ -371,7 +405,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
- "loss": f"{mean_loss:.7f}",
+ "loss": f"{previous_mean_loss:.7f}",
"learn_rate": scheduler.learn_rate
})
@@ -420,14 +454,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f"""
-Loss: {mean_loss:.7f}
+Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
Last prompt: {html.escape(entries[0].cond_text)}
Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
-
+
+ report_statistics(loss_dict)
checkpoint = sd_models.select_checkpoint()
hypernetwork.sd_checkpoint = checkpoint.hash
@@ -438,5 +473,3 @@ Last saved image: {html.escape(last_saved_image)}
hypernetwork.save(filename)
return hypernetwork, filename
-
-
--
cgit v1.2.3
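The statistics helper reports mean ± standard error (stdev divided by sqrt(n)) over each file's recorded losses; statistics.stdev needs at least two samples, hence the try/except added to report_statistics a few commits later. A quick check of the formula with made-up loss values:

from statistics import mean, stdev

def loss_summary(data):
    # mean ± standard error of the mean, matching statistics() above
    return f"loss:{mean(data):.3f}\u00b1({stdev(data) / (len(data) ** 0.5):.3f})"

losses = [0.21, 0.25, 0.19, 0.23, 0.22]
print(loss_summary(losses))        # loss:0.220±(0.010)
print(loss_summary(losses[-32:]))  # identical here, since len(losses) < 32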
From 4fdb53c1e9962507fc8336dad9a0fabfe6c418c0 Mon Sep 17 00:00:00 2001
From: Unnoen
Date: Wed, 19 Oct 2022 21:38:10 +1100
Subject: Generate grid preview for progress image
---
modules/sd_samplers.py | 26 +++++++++++++++++++++++++-
modules/shared.py | 1 +
modules/ui.py | 5 ++++-
3 files changed, 30 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index f58a29b9..74a480e5 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,7 +7,7 @@ import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, processing
+from modules import prompt_parser, devices, processing, images
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -89,6 +89,30 @@ def sample_to_image(samples):
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
+def samples_to_image_grid(samples):
+ progress_images = []
+ for i in range(len(samples)):
+ # Decode the samples individually to reduce VRAM usage at the cost of a bit of speed.
+ x_sample = processing.decode_first_stage(shared.sd_model, samples[i:i+1])[0]
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ progress_images.append(Image.fromarray(x_sample))
+
+ return images.image_grid(progress_images)
+
+def samples_to_image_grid_combined(samples):
+ progress_images = []
+ # Decode all samples at once to increase speed at the cost of VRAM usage.
+ x_samples = processing.decode_first_stage(shared.sd_model, samples)
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for x_sample in x_samples:
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ progress_images.append(Image.fromarray(x_sample))
+
+ return images.image_grid(progress_images)
def store_latent(decoded):
state.current_latent = decoded
diff --git a/modules/shared.py b/modules/shared.py
index d9cb65ef..95d6e225 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -294,6 +294,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+ "progress_decode_combined": OptionInfo(False, "Decode all progress images at once. (Slighty speeds up progress generation but consumes significantly more VRAM with large batches.)"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
diff --git a/modules/ui.py b/modules/ui.py
index 56c233ab..de0abc7e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -318,7 +318,10 @@ def check_progress_call(id_part):
if shared.parallel_processing_allowed:
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
- shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
+ if opts.progress_decode_combined:
+ shared.state.current_image = modules.sd_samplers.samples_to_image_grid_combined(shared.state.current_latent)
+ else:
+ shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
image = shared.state.current_image
--
cgit v1.2.3
From d213d6ca6f90094cb45c11e2f3cb37d25a8d1f94 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 20:48:13 +0300
Subject: removed the option to use 2x more memory when generating previews;
added an option to always show only one image in previews; removed duplicate
code
---
modules/sd_samplers.py | 35 ++++++++++-------------------------
modules/shared.py | 2 +-
modules/ui.py | 6 +++---
3 files changed, 14 insertions(+), 29 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 74a480e5..0b408a70 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -71,6 +71,7 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
+
def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None:
steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
@@ -82,37 +83,21 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
-def sample_to_image(samples):
- x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
+def single_sample_to_image(sample):
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
+
+def sample_to_image(samples):
+ return single_sample_to_image(samples[0])
+
+
def samples_to_image_grid(samples):
- progress_images = []
- for i in range(len(samples)):
- # Decode the samples individually to reduce VRAM usage at the cost of a bit of speed.
- x_sample = processing.decode_first_stage(shared.sd_model, samples[i:i+1])[0]
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- progress_images.append(Image.fromarray(x_sample))
-
- return images.image_grid(progress_images)
-
-def samples_to_image_grid_combined(samples):
- progress_images = []
- # Decode all samples at once to increase speed at the cost of VRAM usage.
- x_samples = processing.decode_first_stage(shared.sd_model, samples)
- x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
- for x_sample in x_samples:
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- progress_images.append(Image.fromarray(x_sample))
-
- return images.image_grid(progress_images)
+ return images.image_grid([single_sample_to_image(sample) for sample in samples])
+
def store_latent(decoded):
state.current_latent = decoded
diff --git a/modules/shared.py b/modules/shared.py
index 95d6e225..25bfc895 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -294,7 +294,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
- "progress_decode_combined": OptionInfo(False, "Decode all progress images at once. (Slighty speeds up progress generation but consumes significantly more VRAM with large batches.)"),
+ "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
diff --git a/modules/ui.py b/modules/ui.py
index de0abc7e..ffa14cac 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -318,10 +318,10 @@ def check_progress_call(id_part):
if shared.parallel_processing_allowed:
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
- if opts.progress_decode_combined:
- shared.state.current_image = modules.sd_samplers.samples_to_image_grid_combined(shared.state.current_latent)
- else:
+ if opts.show_progress_grid:
shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
+ else:
+ shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
image = shared.state.current_image
--
cgit v1.2.3
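Both helpers now funnel through the same conversion: a decoded CHW sample in [-1, 1] becomes an HWC uint8 image. A self-contained sketch of just that conversion, with a random tensor standing in for the VAE-decoded latent:

import numpy as np
import torch
from PIL import Image

x_sample = torch.rand(3, 64, 64) * 2.0 - 1.0  # stand-in for a decoded sample
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
image = Image.fromarray(x_sample.astype(np.uint8))
print(image.size)  # (64, 64)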
From be748e8b086bd9834d08bdd9160649a5e7700af7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 22:05:22 +0300
Subject: add --freeze-settings commandline argument to disable changing
settings
---
modules/shared.py | 1 +
modules/ui.py | 11 +++++++++--
2 files changed, 10 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 25bfc895..b55371d3 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -64,6 +64,7 @@ parser.add_argument("--port", type=int, help="launch gradio with given server po
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
+parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
diff --git a/modules/ui.py b/modules/ui.py
index ffa14cac..2311572c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -580,6 +580,9 @@ def apply_setting(key, value):
if value is None:
return gr.update()
+ if shared.cmd_opts.freeze_settings:
+ return gr.update()
+
# dont allow model to be swapped when model hash exists in prompt
if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
return gr.update()
@@ -1501,6 +1504,8 @@ Requested path was: {f}
def run_settings(*args):
changed = 0
+ assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
+
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
@@ -1530,6 +1535,8 @@ Requested path was: {f}
return f'{changed} settings changed.', opts.dumpjson()
def run_settings_single(value, key):
+ assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
+
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
@@ -1582,7 +1589,7 @@ Requested path was: {f}
elem_id, text = item.section
gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='<h1 class="gr-button-lg">{}</h1>'.format(text))
- if k in quicksettings_names:
+ if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
else:
@@ -1615,7 +1622,7 @@ Requested path was: {f}
def reload_scripts():
modules.scripts.reload_script_body_only()
- reload_javascript() # need to refresh the html page
+ reload_javascript() # need to refresh the html page
reload_script_bodies.click(
fn=reload_scripts,
--
cgit v1.2.3
From ca5a9e79dc28eeaa3a161427a82e34703bf15765 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 22:06:54 +0300
Subject: fix for img2img color correction in a batch #3218
---
modules/processing.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 27c669b0..b1877b80 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -403,8 +403,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if (len(prompts) == 0):
break
- #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
- #c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
@@ -716,6 +714,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size
+
+ if self.color_corrections is not None and len(self.color_corrections) == 1:
+ self.color_corrections = self.color_corrections * self.batch_size
+
elif len(imgs) <= self.batch_size:
self.batch_size = len(imgs)
batch_images = np.array(imgs)
--
cgit v1.2.3
From 48dbf99e84045ee7af55bc5b1b86492a240e631e Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sun, 23 Oct 2022 04:17:16 +0900
Subject: Allow tracking real-time loss
Someone had 6000 images in their dataset, and it was shown as 0, which was confusing.
This will allow tracking the real-time dataset-average loss for registered objects.
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 81132be4..99fd0f8f 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -360,7 +360,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
- if loss_dict and i % size == 0:
+ if len(loss_dict) > 0:
previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
scheduler.apply(optimizer, hypernetwork.step)
--
cgit v1.2.3
From ce42879438bf2dbd76b5b346be656292e42ffb2b Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Sat, 22 Oct 2022 14:53:37 -0500
Subject: fix js func signature and initialize the confirmation var to prevent
an exception when the user cancels the confirmation
---
modules/ui.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 25aeba3b..e58f040e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -429,10 +429,12 @@ def create_seed_inputs():
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
-def clear_prompt(_prompt, _prompt_neg, confirmed, _token_counter):
+def clear_prompt(prompt, _prompt_neg, confirmed, _token_counter):
"""Given confirmation from a user on the client-side, go ahead with clearing prompt"""
if confirmed:
return ["", "", confirmed, update_token_counter("", 1)]
+ else:
+ return [prompt, _prompt_neg, confirmed, _token_counter]
def connect_clear_prompt(button, prompt, prompt_neg, _dummy_confirmed, token_counter):
--
cgit v1.2.3
From 1b4d04737ac513cbd55958bb60a4f85166f3484b Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sat, 22 Oct 2022 20:13:16 -0300
Subject: Remove unused imports
---
modules/api/api.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 5b0c934e..a5136b4b 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,11 +1,9 @@
from modules.api.processing import StableDiffusionProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
from modules.sd_samplers import all_samplers
-from modules.extras import run_pnginfo
import modules.shared as shared
import uvicorn
-from fastapi import Body, APIRouter, HTTPException
-from fastapi.responses import JSONResponse
+from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field, Json
import json
import io
@@ -18,7 +16,6 @@ class TextToImageResponse(BaseModel):
parameters: Json
info: Json
-
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
--
cgit v1.2.3
From b02926df1393df311db734af149fb9faf4389cbe Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sat, 22 Oct 2022 20:24:04 -0300
Subject: Moved models to their own file and extracted base64 conversion to
its own function
---
modules/api/api.py | 17 ++++++-----------
modules/api/models.py | 8 ++++++++
2 files changed, 14 insertions(+), 11 deletions(-)
create mode 100644 modules/api/models.py
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index a5136b4b..c17d7580 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -4,17 +4,17 @@ from modules.sd_samplers import all_samplers
import modules.shared as shared
import uvicorn
from fastapi import APIRouter, HTTPException
-from pydantic import BaseModel, Field, Json
import json
import io
import base64
+from modules.api.models import *
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
-class TextToImageResponse(BaseModel):
- images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: Json
- info: Json
+def img_to_base64(img):
+ buffer = io.BytesIO()
+ img.save(buffer, format="png")
+ return base64.b64encode(buffer.getvalue())
class Api:
def __init__(self, app, queue_lock):
@@ -41,15 +41,10 @@ class Api:
with self.queue_lock:
processed = process_images(p)
- b64images = []
- for i in processed.images:
- buffer = io.BytesIO()
- i.save(buffer, format="png")
- b64images.append(base64.b64encode(buffer.getvalue()))
+ b64images = list(map(img_to_base64, processed.images))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
-
def img2imgapi(self):
raise NotImplementedError
diff --git a/modules/api/models.py b/modules/api/models.py
new file mode 100644
index 00000000..a7d247d8
--- /dev/null
+++ b/modules/api/models.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel, Field, Json
+
+class TextToImageResponse(BaseModel):
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: Json
+ info: Json
+
+
\ No newline at end of file
--
cgit v1.2.3
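The extracted helper PNG-encodes a PIL image to base64 bytes; a round-trip sketch (the inverse decode helper only lands in the next commit):

import base64
import io

from PIL import Image

def img_to_base64(img):
    buffer = io.BytesIO()
    img.save(buffer, format="png")
    return base64.b64encode(buffer.getvalue())

original = Image.new("RGB", (8, 8), color=(255, 0, 0))
b64 = img_to_base64(original)
restored = Image.open(io.BytesIO(base64.b64decode(b64)))
assert list(restored.getdata()) == list(original.getdata())  # PNG is lossless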
From 28e26c2bef217ae82eb9e980cceb3f67ef22e109 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sat, 22 Oct 2022 23:13:32 -0300
Subject: Add "extra" single image operation
- Separate extra modes into 3 endpoints so the user doesn't have to
handle so many unused parameters.
- Add response model for documentation
---
modules/api/api.py | 43 ++++++++++++++++++++++++++++++++++++++-----
modules/api/models.py | 26 +++++++++++++++++++++++++-
2 files changed, 63 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index c17d7580..3b804373 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -8,20 +8,42 @@ import json
import io
import base64
from modules.api.models import *
+from PIL import Image
+from modules.extras import run_extras
+
+def upscaler_to_index(name: str):
+ try:
+ return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
+ except:
+ raise HTTPException(status_code=400, detail="Upscaler not found")
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
-def img_to_base64(img):
+def img_to_base64(img: Image.Image):
buffer = io.BytesIO()
img.save(buffer, format="png")
return base64.b64encode(buffer.getvalue())
+def base64_to_bytes(base64Img: str):
+ if "," in base64Img:
+ base64Img = base64Img.split(",")[1]
+ return io.BytesIO(base64.b64decode(base64Img))
+
+def base64_to_images(base64Imgs: list[str]):
+ imgs = []
+ for img in base64Imgs:
+ img = Image.open(base64_to_bytes(img))
+ imgs.append(img)
+ return imgs
+
+
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
- self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
+ self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -45,12 +67,23 @@ class Api:
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
-
def img2imgapi(self):
raise NotImplementedError
- def extrasapi(self):
- raise NotImplementedError
+ def extras_single_image_api(self, req: ExtrasSingleImageRequest):
+ upscaler1Index = upscaler_to_index(req.upscaler_1)
+ upscaler2Index = upscaler_to_index(req.upscaler_2)
+
+ reqDict = vars(req)
+ reqDict.pop('upscaler_1')
+ reqDict.pop('upscaler_2')
+
+ reqDict['image'] = base64_to_images([reqDict['image']])[0]
+
+ with self.queue_lock:
+ result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="")
+
+ return ExtrasSingleImageResponse(image="data:image/png;base64,"+img_to_base64(result[0]), html_info_x=result[1], html_info=result[2])
def pnginfoapi(self):
raise NotImplementedError
diff --git a/modules/api/models.py b/modules/api/models.py
index a7d247d8..dcf1ab54 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,8 +1,32 @@
from pydantic import BaseModel, Field, Json
+from typing_extensions import Literal
+from modules.shared import sd_upscalers
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
-
\ No newline at end of file
+class ExtrasBaseRequest(BaseModel):
+ resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
+ show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
+ gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
+ codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
+ codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
+ upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
+ upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
+ upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
+ upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
+ upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+ upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
+ extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
+
+class ExtraBaseResponse(BaseModel):
+ html_info_x: str
+ html_info: str
+
+class ExtrasSingleImageRequest(ExtrasBaseRequest):
+ image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+
+class ExtrasSingleImageResponse(ExtraBaseResponse):
+ image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
\ No newline at end of file
--
cgit v1.2.3
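A hypothetical client call against the new endpoint (host, port, upscaler name, and input file are placeholder assumptions; field names follow ExtrasSingleImageRequest above, and the requests package is assumed to be installed):

import base64
import requests

with open("input.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

payload = {
    "resize_mode": 0,           # upscale by factor
    "upscaling_resize": 2,
    "upscaler_1": "Lanczos",    # must match a name in shared.sd_upscalers
    "image": image_b64,
}
r = requests.post("http://127.0.0.1:7860/sdapi/v1/extra-single-image", json=payload)
r.raise_for_status()
result = r.json()["image"]      # "data:image/png;base64,..." per the patch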
From 1fbfc052eb529d8cf8ce5baf578bcf93d0280c29 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 23 Oct 2022 05:43:34 +0100
Subject: Update hypernetwork.py
---
modules/hypernetworks/hypernetwork.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 99fd0f8f..98a7b62e 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -288,10 +288,13 @@ def statistics(data):
def report_statistics(loss_info:dict):
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
for key in keys:
- info, recent = statistics(loss_info[key])
- print("Loss statistics for file " + key)
- print(info)
- print(recent)
+ try:
+ print("Loss statistics for file " + key)
+ info, recent = statistics(loss_info[key])
+ print(info)
+ print(recent)
+ except Exception as e:
+ print(e)
--
cgit v1.2.3
From a7c213d0f5ebb10722629b8490a5863f9ce6c4fa Mon Sep 17 00:00:00 2001
From: Stephen
Date: Fri, 21 Oct 2022 19:27:40 -0400
Subject: [API][Feature] - Add img2img API endpoint
---
modules/api/api.py | 58 +++++++++++++++++++++++++++++++++++++++++++----
modules/api/processing.py | 11 +++++++--
modules/processing.py | 2 +-
3 files changed, 63 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 5b0c934e..a04f2428 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,5 +1,5 @@
-from modules.api.processing import StableDiffusionProcessingAPI
-from modules.processing import StableDiffusionProcessingTxt2Img, process_images
+from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, Json
import json
import io
import base64
+from PIL import Image
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
@@ -18,6 +19,11 @@ class TextToImageResponse(BaseModel):
parameters: Json
info: Json
+class ImageToImageResponse(BaseModel):
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: Json
+ info: Json
+
class Api:
def __init__(self, app, queue_lock):
@@ -25,8 +31,9 @@ class Api:
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
- def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
+ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
@@ -54,8 +61,49 @@ class Api:
- def img2imgapi(self):
- raise NotImplementedError
+ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+ sampler_index = sampler_to_index(img2imgreq.sampler_index)
+
+ if sampler_index is None:
+ raise HTTPException(status_code=404, detail="Sampler not found")
+
+
+ init_images = img2imgreq.init_images
+ if init_images is None:
+ raise HTTPException(status_code=404, detail="Init image not found")
+
+
+ populate = img2imgreq.copy(update={ # Override __init__ params
+ "sd_model": shared.sd_model,
+ "sampler_index": sampler_index[0],
+ "do_not_save_samples": True,
+ "do_not_save_grid": True
+ }
+ )
+ p = StableDiffusionProcessingImg2Img(**vars(populate))
+
+ imgs = []
+ for img in init_images:
+ # if has a comma, deal with prefix
+ if "," in img:
+ img = img.split(",")[1]
+ # convert base64 to PIL image
+ img = base64.b64decode(img)
+ img = Image.open(io.BytesIO(img))
+ imgs = [img] * p.batch_size
+
+ p.init_images = imgs
+ # Override object param
+ with self.queue_lock:
+ processed = process_images(p)
+
+ b64images = []
+ for i in processed.images:
+ buffer = io.BytesIO()
+ i.save(buffer, format="png")
+ b64images.append(base64.b64encode(buffer.getvalue()))
+
+ return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))
def extrasapi(self):
raise NotImplementedError
diff --git a/modules/api/processing.py b/modules/api/processing.py
index 4c541241..9f1d65c0 100644
--- a/modules/api/processing.py
+++ b/modules/api/processing.py
@@ -1,7 +1,8 @@
+from array import array
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
-from modules.processing import StableDiffusionProcessingTxt2Img
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
import inspect
@@ -92,8 +93,14 @@ class PydanticModelGenerator:
DynamicModel.__config__.allow_mutation = True
return DynamicModel
-StableDiffusionProcessingAPI = PydanticModelGenerator(
+StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}]
+).generate_model()
+
+StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingImg2Img",
+ StableDiffusionProcessingImg2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}]
).generate_model()
\ No newline at end of file
diff --git a/modules/processing.py b/modules/processing.py
index b1877b80..1557ed8c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -623,7 +623,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
+ def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: str=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
--
cgit v1.2.3
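A minimal client-side sketch for the new endpoint, assuming a local server on the default port; the field names (init_images, denoising_strength, sampler_index) come from the dynamic model defined above, and the prompt and file names are only examples:

    import base64
    import requests

    with open("input.png", "rb") as f:
        b64 = base64.b64encode(f.read()).decode()

    payload = {
        "init_images": [b64],            # list of base64-encoded source images
        "prompt": "a watercolor landscape",
        "denoising_strength": 0.75,
        "sampler_index": "Euler",
    }
    r = requests.post("http://127.0.0.1:7860/sdapi/v1/img2img", json=payload)
    r.raise_for_status()
    # the response images are base64 strings; strip a possible data-URI prefix
    first = r.json()["images"][0]
    with open("output.png", "wb") as f:
        f.write(base64.b64decode(first.split(",", 1)[-1]))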
From 9e1a8b7734a2881451a2efbf80def011ea41ba49 Mon Sep 17 00:00:00 2001
From: Stephen
Date: Sat, 22 Oct 2022 15:42:00 -0400
Subject: non-implemented mask with any type
---
modules/api/api.py | 4 ++++
modules/api/processing.py | 2 +-
modules/processing.py | 2 +-
3 files changed, 6 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index a04f2428..3df6ff96 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -72,6 +72,10 @@ class Api:
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
+ mask = img2imgreq.mask
+ if mask:
+ raise HTTPException(status_code=400, detail="Mask not supported yet")
+
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
diff --git a/modules/api/processing.py b/modules/api/processing.py
index 9f1d65c0..f551fa35 100644
--- a/modules/api/processing.py
+++ b/modules/api/processing.py
@@ -102,5 +102,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
).generate_model()
\ No newline at end of file
diff --git a/modules/processing.py b/modules/processing.py
index 1557ed8c..ff83023c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -623,7 +623,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: str=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
+ def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
--
cgit v1.2.3
From 5dc0739ecdc1ade8fcf4eb77f2a503ef12489f32 Mon Sep 17 00:00:00 2001
From: Stephen
Date: Sat, 22 Oct 2022 17:10:28 -0400
Subject: working mask
---
modules/api/api.py | 20 ++++++++++++--------
1 file changed, 12 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 3df6ff96..3caa83a4 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -33,6 +33,14 @@ class Api:
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
+ def __base64_to_image(self, base64_string):
+ # if has a comma, deal with prefix
+ if "," in base64_string:
+ base64_string = base64_string.split(",")[1]
+ imgdata = base64.b64decode(base64_string)
+ # convert base64 to PIL image
+ return Image.open(io.BytesIO(imgdata))
+
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -74,26 +82,22 @@ class Api:
mask = img2imgreq.mask
if mask:
- raise HTTPException(status_code=400, detail="Mask not supported yet")
+ mask = self.__base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
- "do_not_save_grid": True
+ "do_not_save_grid": True,
+ "mask": mask
}
)
p = StableDiffusionProcessingImg2Img(**vars(populate))
imgs = []
for img in init_images:
- # if has a comma, deal with prefix
- if "," in img:
- img = img.split(",")[1]
- # convert base64 to PIL image
- img = base64.b64decode(img)
- img = Image.open(io.BytesIO(img))
+ img = self.__base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs
--
cgit v1.2.3
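The __base64_to_image helper above accepts both raw base64 and data-URI strings; a self-contained sketch of the same round trip:

    import base64
    import io
    from PIL import Image

    def base64_to_image(base64_string: str) -> Image.Image:
        # strip a "data:image/png;base64," style prefix if present
        if "," in base64_string:
            base64_string = base64_string.split(",")[1]
        return Image.open(io.BytesIO(base64.b64decode(base64_string)))

    # encode a blank image, then decode it back
    buf = io.BytesIO()
    Image.new("RGB", (64, 64)).save(buf, format="PNG")
    encoded = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
    assert base64_to_image(encoded).size == (64, 64)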
From 1be5933ba21a3badec42b7b2753d626f849b609d Mon Sep 17 00:00:00 2001
From: captin411
Date: Sun, 23 Oct 2022 04:11:07 -0700
Subject: auto cropping now works with non-square crops
---
modules/textual_inversion/autocrop.py | 509 ++++++++++++++++++----------------
1 file changed, 269 insertions(+), 240 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index 5a551c25..b2f9241c 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -1,241 +1,270 @@
-import cv2
-from collections import defaultdict
-from math import log, sqrt
-import numpy as np
-from PIL import Image, ImageDraw
-
-GREEN = "#0F0"
-BLUE = "#00F"
-RED = "#F00"
-
-
-def crop_image(im, settings):
- """ Intelligently crop an image to the subject matter """
- if im.height > im.width:
- im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width))
- elif im.width > im.height:
- im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height))
- else:
- im = im.resize((settings.crop_width, settings.crop_height))
-
- if im.height == im.width:
- return im
-
- focus = focal_point(im, settings)
-
- # take the focal point and turn it into crop coordinates that try to center over the focal
- # point but then get adjusted back into the frame
- y_half = int(settings.crop_height / 2)
- x_half = int(settings.crop_width / 2)
-
- x1 = focus.x - x_half
- if x1 < 0:
- x1 = 0
- elif x1 + settings.crop_width > im.width:
- x1 = im.width - settings.crop_width
-
- y1 = focus.y - y_half
- if y1 < 0:
- y1 = 0
- elif y1 + settings.crop_height > im.height:
- y1 = im.height - settings.crop_height
-
- x2 = x1 + settings.crop_width
- y2 = y1 + settings.crop_height
-
- crop = [x1, y1, x2, y2]
-
- if settings.annotate_image:
- d = ImageDraw.Draw(im)
- rect = list(crop)
- rect[2] -= 1
- rect[3] -= 1
- d.rectangle(rect, outline=GREEN)
- if settings.destop_view_image:
- im.show()
-
- return im.crop(tuple(crop))
-
-def focal_point(im, settings):
- corner_points = image_corner_points(im, settings)
- entropy_points = image_entropy_points(im, settings)
- face_points = image_face_points(im, settings)
-
- total_points = len(corner_points) + len(entropy_points) + len(face_points)
-
- corner_weight = settings.corner_points_weight
- entropy_weight = settings.entropy_points_weight
- face_weight = settings.face_points_weight
-
- weight_pref_total = corner_weight + entropy_weight + face_weight
-
- # weight things
- pois = []
- if weight_pref_total == 0 or total_points == 0:
- return pois
-
- pois.extend(
- [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ]
- )
- pois.extend(
- [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ]
- )
- pois.extend(
- [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ]
- )
-
- average_point = poi_average(pois, settings)
-
- if settings.annotate_image:
- d = ImageDraw.Draw(im)
- for f in face_points:
- d.rectangle(f.bounding(f.size), outline=RED)
- for f in entropy_points:
- d.rectangle(f.bounding(30), outline=BLUE)
- for poi in pois:
- w = max(4, 4 * 0.5 * sqrt(poi.weight))
- d.ellipse(poi.bounding(w), fill=BLUE)
- d.ellipse(average_point.bounding(25), outline=GREEN)
-
- return average_point
-
-
-def image_face_points(im, settings):
- np_im = np.array(im)
- gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
-
- tries = [
- [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
- ]
-
- for t in tries:
- # print(t[0])
- classifier = cv2.CascadeClassifier(t[0])
- minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
- try:
- faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
- minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
- except:
- continue
-
- if len(faces) > 0:
- rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
- return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects]
- return []
-
-
-def image_corner_points(im, settings):
- grayscale = im.convert("L")
-
- # naive attempt at preventing focal points from collecting at watermarks near the bottom
- gd = ImageDraw.Draw(grayscale)
- gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
-
- np_im = np.array(grayscale)
-
- points = cv2.goodFeaturesToTrack(
- np_im,
- maxCorners=100,
- qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.07,
- useHarrisDetector=False,
- )
-
- if points is None:
- return []
-
- focal_points = []
- for point in points:
- x, y = point.ravel()
- focal_points.append(PointOfInterest(x, y, size=4))
-
- return focal_points
-
-
-def image_entropy_points(im, settings):
- landscape = im.height < im.width
- portrait = im.height > im.width
- if landscape:
- move_idx = [0, 2]
- move_max = im.size[0]
- elif portrait:
- move_idx = [1, 3]
- move_max = im.size[1]
- else:
- return []
-
- e_max = 0
- crop_current = [0, 0, settings.crop_width, settings.crop_height]
- crop_best = crop_current
- while crop_current[move_idx[1]] < move_max:
- crop = im.crop(tuple(crop_current))
- e = image_entropy(crop)
-
- if (e > e_max):
- e_max = e
- crop_best = list(crop_current)
-
- crop_current[move_idx[0]] += 4
- crop_current[move_idx[1]] += 4
-
- x_mid = int(crop_best[0] + settings.crop_width/2)
- y_mid = int(crop_best[1] + settings.crop_height/2)
-
- return [PointOfInterest(x_mid, y_mid, size=25)]
-
-
-def image_entropy(im):
- # greyscale image entropy
- # band = np.asarray(im.convert("L"))
- band = np.asarray(im.convert("1"), dtype=np.uint8)
- hist, _ = np.histogram(band, bins=range(0, 256))
- hist = hist[hist > 0]
- return -np.log2(hist / hist.sum()).sum()
-
-
-def poi_average(pois, settings):
- weight = 0.0
- x = 0.0
- y = 0.0
- for poi in pois:
- weight += poi.weight
- x += poi.x * poi.weight
- y += poi.y * poi.weight
- avg_x = round(x / weight)
- avg_y = round(y / weight)
-
- return PointOfInterest(avg_x, avg_y)
-
-
-class PointOfInterest:
- def __init__(self, x, y, weight=1.0, size=10):
- self.x = x
- self.y = y
- self.weight = weight
- self.size = size
-
- def bounding(self, size):
- return [
- self.x - size//2,
- self.y - size//2,
- self.x + size//2,
- self.y + size//2
- ]
-
-
-class Settings:
- def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False):
- self.crop_width = crop_width
- self.crop_height = crop_height
- self.corner_points_weight = corner_points_weight
- self.entropy_points_weight = entropy_points_weight
- self.face_points_weight = entropy_points_weight
- self.annotate_image = annotate_image
+import cv2
+from collections import defaultdict
+from math import log, sqrt
+import numpy as np
+from PIL import Image, ImageDraw
+
+GREEN = "#0F0"
+BLUE = "#00F"
+RED = "#F00"
+
+
+def crop_image(im, settings):
+ """ Intelligently crop an image to the subject matter """
+
+ scale_by = 1
+ if is_landscape(im.width, im.height):
+ scale_by = settings.crop_height / im.height
+ elif is_portrait(im.width, im.height):
+ scale_by = settings.crop_width / im.width
+ elif is_square(im.width, im.height):
+ if is_square(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_landscape(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_portrait(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_height / im.height
+
+ im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+
+ if im.width == settings.crop_width and im.height == settings.crop_height:
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+ rect = [0, 0, im.width, im.height]
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ if settings.destop_view_image:
+ im.show()
+ return im
+
+ focus = focal_point(im, settings)
+
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(settings.crop_height / 2)
+ x_half = int(settings.crop_width / 2)
+
+ x1 = focus.x - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + settings.crop_width > im.width:
+ x1 = im.width - settings.crop_width
+
+ y1 = focus.y - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + settings.crop_height > im.height:
+ y1 = im.height - settings.crop_height
+
+ x2 = x1 + settings.crop_width
+ y2 = y1 + settings.crop_height
+
+ crop = [x1, y1, x2, y2]
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+ rect = list(crop)
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ if settings.destop_view_image:
+ im.show()
+
+ return im.crop(tuple(crop))
+
+def focal_point(im, settings):
+ corner_points = image_corner_points(im, settings)
+ entropy_points = image_entropy_points(im, settings)
+ face_points = image_face_points(im, settings)
+
+ total_points = len(corner_points) + len(entropy_points) + len(face_points)
+
+ corner_weight = settings.corner_points_weight
+ entropy_weight = settings.entropy_points_weight
+ face_weight = settings.face_points_weight
+
+ weight_pref_total = corner_weight + entropy_weight + face_weight
+
+ # weight things
+ pois = []
+ if weight_pref_total == 0 or total_points == 0:
+ return pois
+
+ pois.extend(
+ [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ]
+ )
+ pois.extend(
+ [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ]
+ )
+ pois.extend(
+ [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ]
+ )
+
+ average_point = poi_average(pois, settings)
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im)
+ for f in face_points:
+ d.rectangle(f.bounding(f.size), outline=RED)
+ for f in entropy_points:
+ d.rectangle(f.bounding(30), outline=BLUE)
+ for poi in pois:
+ w = max(4, 4 * 0.5 * sqrt(poi.weight))
+ d.ellipse(poi.bounding(w), fill=BLUE)
+ d.ellipse(average_point.bounding(25), outline=GREEN)
+
+ return average_point
+
+
+def image_face_points(im, settings):
+ np_im = np.array(im)
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
+
+ tries = [
+ [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
+ ]
+
+ for t in tries:
+ # print(t[0])
+ classifier = cv2.CascadeClassifier(t[0])
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
+ try:
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
+ minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+ except:
+ continue
+
+ if len(faces) > 0:
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects]
+ return []
+
+
+def image_corner_points(im, settings):
+ grayscale = im.convert("L")
+
+ # naive attempt at preventing focal points from collecting at watermarks near the bottom
+ gd = ImageDraw.Draw(grayscale)
+ gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
+
+ np_im = np.array(grayscale)
+
+ points = cv2.goodFeaturesToTrack(
+ np_im,
+ maxCorners=100,
+ qualityLevel=0.04,
+ minDistance=min(grayscale.width, grayscale.height)*0.07,
+ useHarrisDetector=False,
+ )
+
+ if points is None:
+ return []
+
+ focal_points = []
+ for point in points:
+ x, y = point.ravel()
+ focal_points.append(PointOfInterest(x, y, size=4))
+
+ return focal_points
+
+
+def image_entropy_points(im, settings):
+ landscape = im.height < im.width
+ portrait = im.height > im.width
+ if landscape:
+ move_idx = [0, 2]
+ move_max = im.size[0]
+ elif portrait:
+ move_idx = [1, 3]
+ move_max = im.size[1]
+ else:
+ return []
+
+ e_max = 0
+ crop_current = [0, 0, settings.crop_width, settings.crop_height]
+ crop_best = crop_current
+ while crop_current[move_idx[1]] < move_max:
+ crop = im.crop(tuple(crop_current))
+ e = image_entropy(crop)
+
+ if (e > e_max):
+ e_max = e
+ crop_best = list(crop_current)
+
+ crop_current[move_idx[0]] += 4
+ crop_current[move_idx[1]] += 4
+
+ x_mid = int(crop_best[0] + settings.crop_width/2)
+ y_mid = int(crop_best[1] + settings.crop_height/2)
+
+ return [PointOfInterest(x_mid, y_mid, size=25)]
+
+
+def image_entropy(im):
+ # greyscale image entropy
+ # band = np.asarray(im.convert("L"))
+ band = np.asarray(im.convert("1"), dtype=np.uint8)
+ hist, _ = np.histogram(band, bins=range(0, 256))
+ hist = hist[hist > 0]
+ return -np.log2(hist / hist.sum()).sum()
+
+
+def poi_average(pois, settings):
+ weight = 0.0
+ x = 0.0
+ y = 0.0
+ for poi in pois:
+ weight += poi.weight
+ x += poi.x * poi.weight
+ y += poi.y * poi.weight
+ avg_x = round(x / weight)
+ avg_y = round(y / weight)
+
+ return PointOfInterest(avg_x, avg_y)
+
+
+def is_landscape(w, h):
+ return w > h
+
+
+def is_portrait(w, h):
+ return h > w
+
+
+def is_square(w, h):
+ return w == h
+
+
+class PointOfInterest:
+ def __init__(self, x, y, weight=1.0, size=10):
+ self.x = x
+ self.y = y
+ self.weight = weight
+ self.size = size
+
+ def bounding(self, size):
+ return [
+ self.x - size//2,
+ self.y - size//2,
+ self.x + size//2,
+ self.y + size//2
+ ]
+
+
+class Settings:
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False):
+ self.crop_width = crop_width
+ self.crop_height = crop_height
+ self.corner_points_weight = corner_points_weight
+ self.entropy_points_weight = entropy_points_weight
+ self.face_points_weight = face_points_weight
+ self.annotate_image = annotate_image
self.destop_view_image = False
\ No newline at end of file
--
cgit v1.2.3
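The key change above is the scaling step: instead of forcing a square resize, the image is scaled so the crop window fits along the limiting side, and the focal-point search then only has to pick the position along the free axis. A reduced sketch of just that rule (same cases as is_landscape/is_portrait/is_square in the patch, not the module itself):

    def scale_factor(w, h, crop_w, crop_h):
        # landscape source: fit the height; portrait source: fit the width;
        # square source: fit whichever target side is the limiting one
        if w > h:
            return crop_h / h
        if h > w:
            return crop_w / w
        return crop_w / w if crop_w >= crop_h else crop_h / h

    # a 1024x683 landscape photo cropped to 512x512 scales to roughly 767x512,
    # so only the horizontal crop position remains to be chosen
    print(scale_factor(1024, 683, 512, 512))  # ~0.75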
From 0523704dade0508bf3ae0c8cb799b1ae332d449b Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sun, 23 Oct 2022 12:27:50 -0300
Subject: Update run_extras to use the temp filename
In batch mode run_extras tries to preserve the original file name of the
images. This makes no sense, since the user only gets a list of images in
the UI, and trying to save them manually shows that these images have
random temp names. Also, keeping "orig_name" in the API is a hassle that
adds complexity to the consuming UI, since the client has to use (or
emulate) an input (type=file) element in a form. Using the normal file
name not only leaves the output and functionality of the original UI
unchanged but also helps keep the API simple.
---
modules/extras.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index 22c5a1c1..29ac312e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -33,7 +33,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for img in image_folder:
image = Image.open(img)
imageArr.append(image)
- imageNameArr.append(os.path.splitext(img.orig_name)[0])
+ imageNameArr.append(os.path.splitext(img.name)[0])
elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
--
cgit v1.2.3
From 4ff852ffb50859f2eae75375cab94dd790a46886 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sun, 23 Oct 2022 13:07:59 -0300
Subject: Add batch processing "extras" endpoint
---
modules/api/api.py | 25 +++++++++++++++++++++++--
modules/api/models.py | 15 ++++++++++++++-
2 files changed, 37 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 3b804373..528134a8 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -10,6 +10,7 @@ import base64
from modules.api.models import *
from PIL import Image
from modules.extras import run_extras
+from gradio import processing_utils
def upscaler_to_index(name: str):
try:
@@ -44,6 +45,7 @@ class Api:
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
+ self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -78,12 +80,31 @@ class Api:
reqDict.pop('upscaler_1')
reqDict.pop('upscaler_2')
- reqDict['image'] = base64_to_images([reqDict['image']])[0]
+ reqDict['image'] = processing_utils.decode_base64_to_file(reqDict['image'])
with self.queue_lock:
result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="")
- return ExtrasSingleImageResponse(image="data:image/png;base64,"+img_to_base64(result[0]), html_info_x=result[1], html_info=result[2])
+ return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0]), html_info_x=result[1], html_info=result[2])
+
+ def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
+ upscaler1Index = upscaler_to_index(req.upscaler_1)
+ upscaler2Index = upscaler_to_index(req.upscaler_2)
+
+ reqDict = vars(req)
+ reqDict.pop('upscaler_1')
+ reqDict.pop('upscaler_2')
+
+ reqDict['image_folder'] = list(map(processing_utils.decode_base64_to_file, reqDict['imageList']))
+ reqDict.pop('imageList')
+
+ with self.queue_lock:
+ result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=1, image="", input_dir="", output_dir="")
+
+ return ExtrasBatchImagesResponse(images=list(map(processing_utils.encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2])
+
+ def extras_folder_processing_api(self):
+ raise NotImplementedError
def pnginfoapi(self):
raise NotImplementedError
diff --git a/modules/api/models.py b/modules/api/models.py
index dcf1ab54..bbd0ef53 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -29,4 +29,17 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
class ExtrasSingleImageResponse(ExtraBaseResponse):
- image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
\ No newline at end of file
+ image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
+
+class SerializableImage(BaseModel):
+ path: str = Field(title="Path", description="The image's path ()")
+
+class ImageItem(BaseModel):
+ data: str = Field(title="image data")
+ name: str = Field(title="filename")
+
+class ExtrasBatchImagesRequest(ExtrasBaseRequest):
+ imageList: list[str] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+
+class ExtrasBatchImagesResponse(ExtraBaseResponse):
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.")
\ No newline at end of file
--
cgit v1.2.3
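A client-side sketch for the new batch route, assuming a local server; imageList carries base64 strings as the request model requires, and "Lanczos" stands in for whatever upscaler names the server actually has installed:

    import base64
    import requests

    def encode_file(path: str) -> str:
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode()

    payload = {
        "imageList": [encode_file("a.png"), encode_file("b.png")],
        "upscaling_resize": 2,
        "upscaler_1": "Lanczos",
    }
    r = requests.post("http://127.0.0.1:7860/sdapi/v1/extra-batch-image", json=payload)
    r.raise_for_status()
    for i, img in enumerate(r.json()["images"]):
        with open(f"upscaled_{i}.png", "wb") as f:
            f.write(base64.b64decode(img.split(",", 1)[-1]))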
From e0ca4dfbc10e0af8dfc4185e5e758f33fd2f0d81 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sun, 23 Oct 2022 15:13:37 -0300
Subject: Update endpoints to use gradio's own utils functions
---
modules/api/api.py | 75 +++++++++++++++++++++++++--------------------------
modules/api/models.py | 4 +--
2 files changed, 38 insertions(+), 41 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 3f490ce2..3acb1f36 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -20,27 +20,27 @@ def upscaler_to_index(name: str):
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
-def img_to_base64(img: str):
- buffer = io.BytesIO()
- img.save(buffer, format="png")
- return base64.b64encode(buffer.getvalue())
-
-def base64_to_bytes(base64Img: str):
- if "," in base64Img:
- base64Img = base64Img.split(",")[1]
- return io.BytesIO(base64.b64decode(base64Img))
-
-def base64_to_images(base64Imgs: list[str]):
- imgs = []
- for img in base64Imgs:
- img = Image.open(base64_to_bytes(img))
- imgs.append(img)
- return imgs
+# def img_to_base64(img: str):
+# buffer = io.BytesIO()
+# img.save(buffer, format="png")
+# return base64.b64encode(buffer.getvalue())
+
+# def base64_to_bytes(base64Img: str):
+# if "," in base64Img:
+# base64Img = base64Img.split(",")[1]
+# return io.BytesIO(base64.b64decode(base64Img))
+
+# def base64_to_images(base64Imgs: list[str]):
+# imgs = []
+# for img in base64Imgs:
+# img = Image.open(base64_to_bytes(img))
+# imgs.append(img)
+# return imgs
class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: Json
- info: Json
+ parameters: dict
+ info: str
class Api:
@@ -49,17 +49,17 @@ class Api:
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
- self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
- def __base64_to_image(self, base64_string):
- # if has a comma, deal with prefix
- if "," in base64_string:
- base64_string = base64_string.split(",")[1]
- imgdata = base64.b64decode(base64_string)
- # convert base64 to PIL image
- return Image.open(io.BytesIO(imgdata))
+ # def __base64_to_image(self, base64_string):
+ # # if has a comma, deal with prefix
+ # if "," in base64_string:
+ # base64_string = base64_string.split(",")[1]
+ # imgdata = base64.b64decode(base64_string)
+ # # convert base64 to PIL image
+ # return Image.open(io.BytesIO(imgdata))
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -79,11 +79,9 @@ class Api:
with self.queue_lock:
processed = process_images(p)
- b64images = list(map(img_to_base64, processed.images))
-
- return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
-
+ b64images = list(map(processing_utils.encode_pil_to_base64, processed.images))
+ return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.info)
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
@@ -98,7 +96,7 @@ class Api:
mask = img2imgreq.mask
if mask:
- mask = self.__base64_to_image(mask)
+ mask = processing_utils.decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
@@ -113,7 +111,7 @@ class Api:
imgs = []
for img in init_images:
- img = self.__base64_to_image(img)
+ img = processing_utils.decode_base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs
@@ -121,13 +119,12 @@ class Api:
with self.queue_lock:
processed = process_images(p)
- b64images = []
- for i in processed.images:
- buffer = io.BytesIO()
- i.save(buffer, format="png")
- b64images.append(base64.b64encode(buffer.getvalue()))
-
- return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))
+ b64images = list(map(processing_utils.encode_pil_to_base64, processed.images))
+ # for i in processed.images:
+ # buffer = io.BytesIO()
+ # i.save(buffer, format="png")
+ # b64images.append(base64.b64encode(buffer.getvalue()))
+ return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info)
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
upscaler1Index = upscaler_to_index(req.upscaler_1)
diff --git a/modules/api/models.py b/modules/api/models.py
index bbd0ef53..209f8af5 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -4,8 +4,8 @@ from modules.shared import sd_upscalers
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: Json
- info: Json
+ parameters: str
+ info: str
class ExtrasBaseRequest(BaseModel):
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
--
cgit v1.2.3
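Delegating to gradio's helpers removes the hand-rolled base64 code entirely; a minimal sketch of the round trip these utilities provide (assuming a gradio version whose processing_utils exposes these functions, as this patch does):

    from PIL import Image
    from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_image

    im = Image.new("RGB", (64, 64), "white")
    encoded = encode_pil_to_base64(im)         # data-URI style base64 string
    decoded = decode_base64_to_image(encoded)  # back to a PIL image
    assert decoded.size == im.size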
From 866b36d705a338d299aba385788729d60f7d48c8 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sun, 23 Oct 2022 15:35:49 -0300
Subject: Move processing's models into models.py
It didn't make sense to have two different files for the same thing, and
"models" is a more descriptive name.
---
modules/api/api.py | 57 ++++-------------------
modules/api/models.py | 112 +++++++++++++++++++++++++++++++++++++++++++++-
modules/api/processing.py | 106 -------------------------------------------
3 files changed, 119 insertions(+), 156 deletions(-)
delete mode 100644 modules/api/processing.py
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 3acb1f36..20e85e82 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,16 +1,11 @@
-from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
-from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers
-import modules.shared as shared
import uvicorn
+from gradio import processing_utils
from fastapi import APIRouter, HTTPException
-import json
-import io
-import base64
+import modules.shared as shared
from modules.api.models import *
-from PIL import Image
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.sd_samplers import all_samplers
from modules.extras import run_extras
-from gradio import processing_utils
def upscaler_to_index(name: str):
try:
@@ -20,29 +15,6 @@ def upscaler_to_index(name: str):
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
-# def img_to_base64(img: str):
-# buffer = io.BytesIO()
-# img.save(buffer, format="png")
-# return base64.b64encode(buffer.getvalue())
-
-# def base64_to_bytes(base64Img: str):
-# if "," in base64Img:
-# base64Img = base64Img.split(",")[1]
-# return io.BytesIO(base64.b64decode(base64Img))
-
-# def base64_to_images(base64Imgs: list[str]):
-# imgs = []
-# for img in base64Imgs:
-# img = Image.open(base64_to_bytes(img))
-# imgs.append(img)
-# return imgs
-
-class ImageToImageResponse(BaseModel):
- images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: dict
- info: str
-
-
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
@@ -51,15 +23,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
- self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
-
- # def __base64_to_image(self, base64_string):
- # # if has a comma, deal with prefix
- # if "," in base64_string:
- # base64_string = base64_string.split(",")[1]
- # imgdata = base64.b64decode(base64_string)
- # # convert base64 to PIL image
- # return Image.open(io.BytesIO(imgdata))
+ self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -81,7 +45,7 @@ class Api:
b64images = list(map(processing_utils.encode_pil_to_base64, processed.images))
- return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.info)
+ return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.info)
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
@@ -120,10 +84,7 @@ class Api:
processed = process_images(p)
b64images = list(map(processing_utils.encode_pil_to_base64, processed.images))
- # for i in processed.images:
- # buffer = io.BytesIO()
- # i.save(buffer, format="png")
- # b64images.append(base64.b64encode(buffer.getvalue()))
+
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info)
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
@@ -134,12 +95,12 @@ class Api:
reqDict.pop('upscaler_1')
reqDict.pop('upscaler_2')
- reqDict['image'] = processing_utils.decode_base64_to_file(reqDict['image'])
+ reqDict['image'] = processing_utils.decode_base64_to_image(reqDict['image'])
with self.queue_lock:
result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="")
- return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0]), html_info_x=result[1], html_info=result[2])
+ return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2])
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
upscaler1Index = upscaler_to_index(req.upscaler_1)
diff --git a/modules/api/models.py b/modules/api/models.py
index 209f8af5..362e6277 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,10 +1,118 @@
-from pydantic import BaseModel, Field, Json
+import inspect
+from pydantic import BaseModel, Field, Json, create_model
+from typing import Any, Optional
from typing_extensions import Literal
+from inflection import underscore
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers
+API_NOT_ALLOWED = [
+ "self",
+ "kwargs",
+ "sd_model",
+ "outpath_samples",
+ "outpath_grids",
+ "sampler_index",
+ "do_not_save_samples",
+ "do_not_save_grid",
+ "extra_generation_params",
+ "overlay_images",
+ "do_not_reload_embeddings",
+ "seed_enable_extras",
+ "prompt_for_display",
+ "sampler_noise_scheduler_override",
+ "ddim_discretize"
+]
+
+class ModelDef(BaseModel):
+ """Assistance Class for Pydantic Dynamic Model Generation"""
+
+ field: str
+ field_alias: str
+ field_type: Any
+ field_value: Any
+
+
+class PydanticModelGenerator:
+ """
+ Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
+ source_data is a snapshot of the default values produced by the class
+ params are the names of the actual keys required by __init__
+ """
+
+ def __init__(
+ self,
+ model_name: str = None,
+ class_instance = None,
+ additional_fields = None,
+ ):
+ def field_type_generator(k, v):
+ # field_type = str if not overrides.get(k) else overrides[k]["type"]
+ # print(k, v.annotation, v.default)
+ field_type = v.annotation
+
+ return Optional[field_type]
+
+ def merge_class_params(class_):
+ all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
+ parameters = {}
+ for classes in all_classes:
+ parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
+ return parameters
+
+
+ self._model_name = model_name
+ self._class_data = merge_class_params(class_instance)
+ self._model_def = [
+ ModelDef(
+ field=underscore(k),
+ field_alias=k,
+ field_type=field_type_generator(k, v),
+ field_value=v.default
+ )
+ for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
+ ]
+
+ for fields in additional_fields:
+ self._model_def.append(ModelDef(
+ field=underscore(fields["key"]),
+ field_alias=fields["key"],
+ field_type=fields["type"],
+ field_value=fields["default"]))
+
+ def generate_model(self):
+ """
+ Creates a pydantic BaseModel
+ from the json and overrides provided at initialization
+ """
+ fields = {
+ d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
+ }
+ DynamicModel = create_model(self._model_name, **fields)
+ DynamicModel.__config__.allow_population_by_field_name = True
+ DynamicModel.__config__.allow_mutation = True
+ return DynamicModel
+
+StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingTxt2Img",
+ StableDiffusionProcessingTxt2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}]
+).generate_model()
+
+StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingImg2Img",
+ StableDiffusionProcessingImg2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
+).generate_model()
+
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
- parameters: str
+ parameters: dict
+ info: str
+
+class ImageToImageResponse(BaseModel):
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: dict
info: str
class ExtrasBaseRequest(BaseModel):
diff --git a/modules/api/processing.py b/modules/api/processing.py
deleted file mode 100644
index f551fa35..00000000
--- a/modules/api/processing.py
+++ /dev/null
@@ -1,106 +0,0 @@
-from array import array
-from inflection import underscore
-from typing import Any, Dict, Optional
-from pydantic import BaseModel, Field, create_model
-from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
-import inspect
-
-
-API_NOT_ALLOWED = [
- "self",
- "kwargs",
- "sd_model",
- "outpath_samples",
- "outpath_grids",
- "sampler_index",
- "do_not_save_samples",
- "do_not_save_grid",
- "extra_generation_params",
- "overlay_images",
- "do_not_reload_embeddings",
- "seed_enable_extras",
- "prompt_for_display",
- "sampler_noise_scheduler_override",
- "ddim_discretize"
-]
-
-class ModelDef(BaseModel):
- """Assistance Class for Pydantic Dynamic Model Generation"""
-
- field: str
- field_alias: str
- field_type: Any
- field_value: Any
-
-
-class PydanticModelGenerator:
- """
- Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
- source_data is a snapshot of the default values produced by the class
- params are the names of the actual keys required by __init__
- """
-
- def __init__(
- self,
- model_name: str = None,
- class_instance = None,
- additional_fields = None,
- ):
- def field_type_generator(k, v):
- # field_type = str if not overrides.get(k) else overrides[k]["type"]
- # print(k, v.annotation, v.default)
- field_type = v.annotation
-
- return Optional[field_type]
-
- def merge_class_params(class_):
- all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
- parameters = {}
- for classes in all_classes:
- parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
- return parameters
-
-
- self._model_name = model_name
- self._class_data = merge_class_params(class_instance)
- self._model_def = [
- ModelDef(
- field=underscore(k),
- field_alias=k,
- field_type=field_type_generator(k, v),
- field_value=v.default
- )
- for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
- ]
-
- for fields in additional_fields:
- self._model_def.append(ModelDef(
- field=underscore(fields["key"]),
- field_alias=fields["key"],
- field_type=fields["type"],
- field_value=fields["default"]))
-
- def generate_model(self):
- """
- Creates a pydantic BaseModel
- from the json and overrides provided at initialization
- """
- fields = {
- d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
- }
- DynamicModel = create_model(self._model_name, **fields)
- DynamicModel.__config__.allow_population_by_field_name = True
- DynamicModel.__config__.allow_mutation = True
- return DynamicModel
-
-StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
- "StableDiffusionProcessingTxt2Img",
- StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}]
-).generate_model()
-
-StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
- "StableDiffusionProcessingImg2Img",
- StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
-).generate_model()
\ No newline at end of file
--
cgit v1.2.3
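PydanticModelGenerator walks every __init__ signature in the class's MRO and turns the parameters into optional pydantic fields via create_model; a reduced sketch of the same mechanism on a hypothetical toy class (pydantic v1 API, matching the __config__ usage above):

    import inspect
    from typing import Optional
    from pydantic import Field, create_model

    class Toy:
        def __init__(self, prompt: str = "", steps: int = 20):
            self.prompt, self.steps = prompt, steps

    # map __init__ parameters (minus self) to optional, aliased pydantic fields
    params = inspect.signature(Toy.__init__).parameters
    fields = {
        k: (Optional[v.annotation], Field(default=v.default, alias=k))
        for k, v in params.items() if k != "self"
    }
    ToyAPI = create_model("ToyAPI", **fields)
    print(ToyAPI(prompt="hi").dict())  # {'prompt': 'hi', 'steps': 20}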
From 1e625624ba6ab3dfc167f0a5226780bb9b50fb58 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sun, 23 Oct 2022 16:01:16 -0300
Subject: Add folder processing endpoint
Also minor refactor
---
modules/api/api.py | 56 +++++++++++++++++++++++++++------------------------
modules/api/models.py | 6 +++++-
2 files changed, 35 insertions(+), 27 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 20e85e82..7b4fbe29 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,5 +1,5 @@
import uvicorn
-from gradio import processing_utils
+from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, HTTPException
import modules.shared as shared
from modules.api.models import *
@@ -11,10 +11,18 @@ def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except:
- raise HTTPException(status_code=400, detail="Upscaler not found")
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {', '.join([x.name for x in sd_upscalers])}")
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+def setUpscalers(req: dict):
+ reqDict = vars(req)
+ reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
+ reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
+ reqDict.pop('upscaler_1')
+ reqDict.pop('upscaler_2')
+ return reqDict
+
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
@@ -24,6 +32,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
+ self.app.add_api_route("/sdapi/v1/extra-folder-images", self.extras_folder_processing_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -43,7 +52,7 @@ class Api:
with self.queue_lock:
processed = process_images(p)
- b64images = list(map(processing_utils.encode_pil_to_base64, processed.images))
+ b64images = list(map(encode_pil_to_base64, processed.images))
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.info)
@@ -60,7 +69,7 @@ class Api:
mask = img2imgreq.mask
if mask:
- mask = processing_utils.decode_base64_to_image(mask)
+ mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
@@ -75,7 +84,7 @@ class Api:
imgs = []
for img in init_images:
- img = processing_utils.decode_base64_to_image(img)
+ img = decode_base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs
@@ -83,43 +92,38 @@ class Api:
with self.queue_lock:
processed = process_images(p)
- b64images = list(map(processing_utils.encode_pil_to_base64, processed.images))
+ b64images = list(map(encode_pil_to_base64, processed.images))
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info)
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
- upscaler1Index = upscaler_to_index(req.upscaler_1)
- upscaler2Index = upscaler_to_index(req.upscaler_2)
-
- reqDict = vars(req)
- reqDict.pop('upscaler_1')
- reqDict.pop('upscaler_2')
+ reqDict = setUpscalers(req)
- reqDict['image'] = processing_utils.decode_base64_to_image(reqDict['image'])
+ reqDict['image'] = decode_base64_to_image(reqDict['image'])
with self.queue_lock:
- result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="")
+ result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)
- return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2])
+ return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2])
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
- upscaler1Index = upscaler_to_index(req.upscaler_1)
- upscaler2Index = upscaler_to_index(req.upscaler_2)
+ reqDict = setUpscalers(req)
- reqDict = vars(req)
- reqDict.pop('upscaler_1')
- reqDict.pop('upscaler_2')
-
- reqDict['image_folder'] = list(map(processing_utils.decode_base64_to_file, reqDict['imageList']))
+ reqDict['image_folder'] = list(map(decode_base64_to_file, reqDict['imageList']))
reqDict.pop('imageList')
with self.queue_lock:
- result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=1, image="", input_dir="", output_dir="")
+ result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
- return ExtrasBatchImagesResponse(images=list(map(processing_utils.encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2])
+ return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2])
- def extras_folder_processing_api(self):
- raise NotImplementedError
+ def extras_folder_processing_api(self, req:ExtrasFoldersRequest):
+ reqDict = setUpscalers(req)
+
+ with self.queue_lock:
+ result = run_extras(extras_mode=2, image=None, image_folder=None, **reqDict)
+
+ return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2])
def pnginfoapi(self):
raise NotImplementedError
diff --git a/modules/api/models.py b/modules/api/models.py
index 362e6277..6f096807 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -150,4 +150,8 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[str] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: list[str] = Field(title="Images", description="The generated images in base64 format.")
\ No newline at end of file
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.")
+
+class ExtrasFoldersRequest(ExtrasBaseRequest):
+ input_dir: str = Field(title="Input directory", description="Directory path from where to take the images")
+ output_dir: str = Field(title="Output directory", description="Directory path to put the processed images into")
--
cgit v1.2.3
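A client-side sketch for the folder route, assuming a local server; note that input_dir and output_dir are paths on the server's filesystem, not the client's, and the directory names here are only examples:

    import requests

    payload = {
        "input_dir": "/data/in",    # server-side directory with source images
        "output_dir": "/data/out",  # server-side directory for the results
        "upscaling_resize": 2,
        "upscaler_1": "Lanczos",
    }
    r = requests.post("http://127.0.0.1:7860/sdapi/v1/extra-folder-images", json=payload)
    r.raise_for_status()
    print(len(r.json()["images"]), "images processed")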
From 90f02c75220d187e075203a4e3b450bfba392c4d Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sun, 23 Oct 2022 16:03:30 -0300
Subject: Remove unused field and class
---
modules/api/api.py | 6 +++---
modules/api/models.py | 6 +-----
2 files changed, 4 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 7b4fbe29..799e3701 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -104,7 +104,7 @@ class Api:
with self.queue_lock:
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)
- return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2])
+ return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
@@ -115,7 +115,7 @@ class Api:
with self.queue_lock:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
- return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2])
+ return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def extras_folder_processing_api(self, req:ExtrasFoldersRequest):
reqDict = setUpscalers(req)
@@ -123,7 +123,7 @@ class Api:
with self.queue_lock:
result = run_extras(extras_mode=2, image=None, image_folder=None, **reqDict)
- return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2])
+ return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def pnginfoapi(self):
raise NotImplementedError
diff --git a/modules/api/models.py b/modules/api/models.py
index 6f096807..e461d397 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -130,8 +130,7 @@ class ExtrasBaseRequest(BaseModel):
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
class ExtraBaseResponse(BaseModel):
- html_info_x: str
- html_info: str
+ html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
class ExtrasSingleImageRequest(ExtrasBaseRequest):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
@@ -139,9 +138,6 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest):
class ExtrasSingleImageResponse(ExtraBaseResponse):
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
-class SerializableImage(BaseModel):
- path: str = Field(title="Path", description="The image's path ()")
-
class ImageItem(BaseModel):
data: str = Field(title="image data")
name: str = Field(title="filename")
--
cgit v1.2.3
From 124e44cf1eed1edc68954f63a2a9bc428aabbcec Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Mon, 24 Oct 2022 09:51:56 +0800
Subject: move image browser out to an extension
---
modules/images_history.py | 424 --------------------------------------------
modules/inspiration.py | 193 --------------------
modules/script_callbacks.py | 2 -
modules/shared.py | 15 --
modules/ui.py | 20 +--
5 files changed, 4 insertions(+), 650 deletions(-)
delete mode 100644 modules/images_history.py
delete mode 100644 modules/inspiration.py
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
deleted file mode 100644
index bc5cf11f..00000000
--- a/modules/images_history.py
+++ /dev/null
@@ -1,424 +0,0 @@
-import os
-import shutil
-import time
-import hashlib
-import gradio
-system_bak_path = "webui_log_and_bak"
-custom_tab_name = "custom fold"
-faverate_tab_name = "favorites"
-tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name]
-def is_valid_date(date):
- try:
- time.strptime(date, "%Y%m%d")
- return True
- except:
- return False
-
-def reduplicative_file_move(src, dst):
- def same_name_file(basename, path):
- name, ext = os.path.splitext(basename)
- f_list = os.listdir(path)
- max_num = 0
- for f in f_list:
- if len(f) <= len(basename):
- continue
- f_ext = f[-len(ext):] if len(ext) > 0 else ""
- if f[:len(name)] == name and f_ext == ext:
- if f[len(name)] == "(" and f[-len(ext)-1] == ")":
- number = f[len(name)+1:-len(ext)-1]
- if number.isdigit():
- if int(number) > max_num:
- max_num = int(number)
- return f"{name}({max_num + 1}){ext}"
- name = os.path.basename(src)
- save_name = os.path.join(dst, name)
- if not os.path.exists(save_name):
- shutil.move(src, dst)
- else:
- name = same_name_file(name, dst)
- shutil.move(src, os.path.join(dst, name))
-
-def traverse_all_files(curr_path, image_list, all_type=False):
- try:
- f_list = os.listdir(curr_path)
- except:
- if all_type or (curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt" and curr_path[-4:] != ".csv"):
- image_list.append(curr_path)
- return image_list
- for file in f_list:
- file = os.path.join(curr_path, file)
- if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"):
- pass
- elif os.path.isfile(file) and file[-10:].rfind(".") > 0:
- image_list.append(file)
- else:
- image_list = traverse_all_files(file, image_list)
- return image_list
-
-def auto_sorting(dir_name):
- bak_path = os.path.join(dir_name, system_bak_path)
- if not os.path.exists(bak_path):
- os.mkdir(bak_path)
- log_file = None
- files_list = []
- f_list = os.listdir(dir_name)
- for file in f_list:
- if file == system_bak_path:
- continue
- file_path = os.path.join(dir_name, file)
- if not is_valid_date(file):
- if file[-10:].rfind(".") > 0:
- files_list.append(file_path)
- else:
- files_list = traverse_all_files(file_path, files_list, all_type=True)
-
- for file in files_list:
- date_str = time.strftime("%Y%m%d",time.localtime(os.path.getmtime(file)))
- file_path = os.path.dirname(file)
- hash_path = hashlib.md5(file_path.encode()).hexdigest()
- path = os.path.join(dir_name, date_str, hash_path)
- if not os.path.exists(path):
- os.makedirs(path)
- if log_file is None:
- log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a")
- log_file.write(f"{hash_path},{file_path}\n")
- reduplicative_file_move(file, path)
-
- date_list = []
- f_list = os.listdir(dir_name)
- for f in f_list:
- if is_valid_date(f):
- date_list.append(f)
- elif f == system_bak_path:
- continue
- else:
- try:
- reduplicative_file_move(os.path.join(dir_name, f), bak_path)
- except:
- pass
-
- today = time.strftime("%Y%m%d",time.localtime(time.time()))
- if today not in date_list:
- date_list.append(today)
- return sorted(date_list, reverse=True)
-
-def archive_images(dir_name, date_to):
- filenames = []
- batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num)
- if batch_size <= 0:
- batch_size = opts.images_history_num_per_page * 6
- today = time.strftime("%Y%m%d",time.localtime(time.time()))
- date_to = today if date_to is None or date_to == "" else date_to
- date_to_bak = date_to
- if False: #opts.images_history_reconstruct_directory:
- date_list = auto_sorting(dir_name)
- for date in date_list:
- if date <= date_to:
- path = os.path.join(dir_name, date)
- if date == today and not os.path.exists(path):
- continue
- filenames = traverse_all_files(path, filenames)
- if len(filenames) > batch_size:
- break
- filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file))
- else:
- filenames = traverse_all_files(dir_name, filenames)
- total_num = len(filenames)
- tmparray = [(os.path.getmtime(file), file) for file in filenames ]
- date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400
- filenames = []
- date_list = {date_to:None}
- date = time.strftime("%Y%m%d",time.localtime(time.time()))
- for t, f in tmparray:
- date = time.strftime("%Y%m%d",time.localtime(t))
- date_list[date] = None
- if t <= date_stamp:
- filenames.append((t, f ,date))
- date_list = sorted(list(date_list.keys()), reverse=True)
- sort_array = sorted(filenames, key=lambda x:-x[0])
- if len(sort_array) > batch_size:
- date = sort_array[batch_size][2]
- filenames = [x[1] for x in sort_array]
- else:
- date = date_to if len(sort_array) == 0 else sort_array[-1][2]
- filenames = [x[1] for x in sort_array]
- filenames = [x[1] for x in sort_array if x[2]>= date]
- num = len(filenames)
- last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000))
- date = date[:4] + "/" + date[4:6] + "/" + date[6:8]
- date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8]
- load_info = ""
- load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages"
- load_info += "
"
- _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames)
- return (
- date_to,
- load_info,
- filenames,
- 1,
- image_list,
- "",
- "",
- visible_num,
- last_date_from,
- gradio.update(visible=total_num > num)
- )
-
-def delete_image(delete_num, name, filenames, image_index, visible_num):
- if name == "":
- return filenames, delete_num
- else:
- delete_num = int(delete_num)
- visible_num = int(visible_num)
- image_index = int(image_index)
- index = list(filenames).index(name)
- i = 0
- new_file_list = []
- for name in filenames:
- if i >= index and i < index + delete_num:
- if os.path.exists(name):
- if visible_num == image_index:
- new_file_list.append(name)
- i += 1
- continue
- print(f"Delete file {name}")
- os.remove(name)
- visible_num -= 1
- txt_file = os.path.splitext(name)[0] + ".txt"
- if os.path.exists(txt_file):
- os.remove(txt_file)
- else:
- print(f"Not exists file {name}")
- else:
- new_file_list.append(name)
- i += 1
- return new_file_list, 1, visible_num
-
-def save_image(file_name):
- if file_name is not None and os.path.exists(file_name):
- shutil.copy(file_name, opts.outdir_save)
-
-def get_recent_images(page_index, step, filenames):
- page_index = int(page_index)
- num_of_imgs_per_page = int(opts.images_history_num_per_page)
- max_page_index = len(filenames) // num_of_imgs_per_page + 1
- page_index = max_page_index if page_index == -1 else page_index + step
- page_index = 1 if page_index < 1 else page_index
- page_index = max_page_index if page_index > max_page_index else page_index
- idx_frm = (page_index - 1) * num_of_imgs_per_page
- image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page]
- length = len(filenames)
- visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page
- visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num
- return page_index, image_list, "", "", visible_num
-
-def loac_batch_click(date_to):
- if date_to is None:
- return time.strftime("%Y%m%d",time.localtime(time.time())), []
- else:
- return None, []
-def forward_click(last_date_from, date_to_recorder):
- if len(date_to_recorder) == 0:
- return None, []
- if last_date_from == date_to_recorder[-1]:
- date_to_recorder = date_to_recorder[:-1]
- if len(date_to_recorder) == 0:
- return None, []
- return date_to_recorder[-1], date_to_recorder[:-1]
-
-def backward_click(last_date_from, date_to_recorder):
- if last_date_from is None or last_date_from == "":
- return time.strftime("%Y%m%d",time.localtime(time.time())), []
- if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]:
- date_to_recorder.append(last_date_from)
- return last_date_from, date_to_recorder
-
-
-def first_page_click(page_index, filenames):
- return get_recent_images(1, 0, filenames)
-
-def end_page_click(page_index, filenames):
- return get_recent_images(-1, 0, filenames)
-
-def prev_page_click(page_index, filenames):
- return get_recent_images(page_index, -1, filenames)
-
-def next_page_click(page_index, filenames):
- return get_recent_images(page_index, 1, filenames)
-
-def page_index_change(page_index, filenames):
- return get_recent_images(page_index, 0, filenames)
-
-def show_image_info(tabname_box, num, page_index, filenames):
- file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))]
- tm = "" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "
"
- return file, tm, num, file
-
-def enable_page_buttons():
- return gradio.update(visible=True)
-
-def change_dir(img_dir, date_to):
- warning = None
- try:
- if os.path.exists(img_dir):
- try:
- f = os.listdir(img_dir)
- except:
- warning = f"'{img_dir} is not a directory"
- else:
- warning = "The directory is not exist"
- except:
- warning = "The format of the directory is incorrect"
- if warning is None:
- today = time.strftime("%Y%m%d",time.localtime(time.time()))
- return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True)
- else:
- return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False)
-
-def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
- custom_dir = False
- if tabname == "txt2img":
- dir_name = opts.outdir_txt2img_samples
- elif tabname == "img2img":
- dir_name = opts.outdir_img2img_samples
- elif tabname == "extras":
- dir_name = opts.outdir_extras_samples
- elif tabname == faverate_tab_name:
- dir_name = opts.outdir_save
- else:
- custom_dir = True
- dir_name = None
-
- if not custom_dir:
- d = dir_name.split("/")
- dir_name = d[0]
- for p in d[1:]:
- dir_name = os.path.join(dir_name, p)
- if not os.path.exists(dir_name):
- os.makedirs(dir_name)
-
- with gr.Column() as page_panel:
- with gr.Row():
- with gr.Column(scale=1, visible=not custom_dir) as load_batch_box:
- load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True)
- with gr.Column(scale=4):
- with gr.Row():
- img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir)
- with gr.Row():
- with gr.Column(visible=False, scale=1) as batch_panel:
- with gr.Row():
- forward = gr.Button('Prev batch')
- backward = gr.Button('Next batch')
- with gr.Column(scale=3):
- load_info = gr.HTML(visible=not custom_dir)
- with gr.Row(visible=False) as warning:
- warning_box = gr.Textbox("Message", interactive=False)
-
- with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel:
- with gr.Column(scale=2):
- with gr.Row(visible=True) as turn_page_buttons:
- #date_to = gr.Dropdown(label="Date to")
- first_page = gr.Button('First Page')
- prev_page = gr.Button('Prev Page')
- page_index = gr.Number(value=1, label="Page Index")
- next_page = gr.Button('Next Page')
- end_page = gr.Button('End Page')
-
- history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num)
- with gr.Row():
- delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
- delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
-
- with gr.Column():
- with gr.Row():
- with gr.Column():
- img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6)
- gr.HTML(" ")
- img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
- img_file_time= gr.HTML()
- with gr.Row():
- if tabname != faverate_tab_name:
- save_btn = gr.Button('Collect')
- pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
- pnginfo_send_to_img2img = gr.Button('Send to img2img')
-
-
- # hidden items
- with gr.Row(visible=False):
- renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
- batch_date_to = gr.Textbox(label="Date to")
- visible_img_num = gr.Number()
- date_to_recorder = gr.State([])
- last_date_from = gr.Textbox()
- tabname_box = gr.Textbox(tabname)
- image_index = gr.Textbox(value=-1)
- set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
- filenames = gr.State()
- all_images_list = gr.State()
- hidden = gr.Image(type="pil")
- info1 = gr.Textbox()
- info2 = gr.Textbox()
-
- img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info])
-
- #change batch
- change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel]
-
- batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output)
- batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
- batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
-
- load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder])
- forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
- backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
-
-
- #delete
- delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num])
- delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None)
- if tabname != faverate_tab_name:
- save_btn.click(save_image, inputs=[img_file_name], outputs=None)
-
- #turn page
- gallery_inputs = [page_index, filenames]
- gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num]
- first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
- next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
- prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
- end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
- page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
- renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
-
- first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
- next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
- prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
- end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
- page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
- renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
-
- # other functions
- set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden])
- img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
- hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
- switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
- switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
-
-
-
-def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict):
- global opts;
- opts = sys_opts
- loads_files_num = int(opts.images_history_num_per_page)
- num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num)
- if cmp_ops.browse_all_images:
- tabs_list.append(custom_tab_name)
- with gr.Blocks(analytics_enabled=False) as images_history:
- with gr.Tabs() as tabs:
- for tab in tabs_list:
- with gr.Tab(tab):
- with gr.Blocks(analytics_enabled=False) :
- show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
- gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False)
- gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False)
-
- return images_history
diff --git a/modules/inspiration.py b/modules/inspiration.py
deleted file mode 100644
index 29cf8297..00000000
--- a/modules/inspiration.py
+++ /dev/null
@@ -1,193 +0,0 @@
-import os
-import random
-import gradio
-from modules.shared import opts
-inspiration_system_path = os.path.join(opts.inspiration_dir, "system")
-def read_name_list(file, types=None, keyword=None):
- if not os.path.exists(file):
- return []
- ret = []
- f = open(file, "r")
- line = f.readline()
- while len(line) > 0:
- line = line.rstrip("\n")
- if types is not None:
- dirname = os.path.split(line)
- if dirname[0] in types and keyword in dirname[1].lower():
- ret.append(line)
- else:
- ret.append(line)
- line = f.readline()
- return ret
-
-def save_name_list(file, name):
- name_list = read_name_list(file)
- if name not in name_list:
- with open(file, "a") as f:
- f.write(name + "\n")
-
-def get_types_list():
- files = os.listdir(opts.inspiration_dir)
- types = []
- for x in files:
- path = os.path.join(opts.inspiration_dir, x)
- if x[0] == ".":
- continue
- if not os.path.isdir(path):
- continue
- if path == inspiration_system_path:
- continue
- types.append(x)
- return types
-
-def get_inspiration_images(source, types, keyword):
- keyword = keyword.strip(" ").lower()
- get_num = int(opts.inspiration_rows_num * opts.inspiration_cols_num)
- if source == "Favorites":
- names = read_name_list(os.path.join(inspiration_system_path, "faverites.txt"), types, keyword)
- names = random.sample(names, get_num) if len(names) > get_num else names
- elif source == "Abandoned":
- names = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword)
- names = random.sample(names, get_num) if len(names) > get_num else names
- elif source == "Exclude abandoned":
- abandoned = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword)
- all_names = []
- for tp in types:
- name_list = os.listdir(os.path.join(opts.inspiration_dir, tp))
- all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()]
-
- if len(all_names) > get_num:
- names = []
- while len(names) < get_num:
- name = random.choice(all_names)
- if name not in abandoned:
- names.append(name)
- else:
- names = all_names
- else:
- all_names = []
- for tp in types:
- name_list = os.listdir(os.path.join(opts.inspiration_dir, tp))
- all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()]
- names = random.sample(all_names, get_num) if len(all_names) > get_num else all_names
- image_list = []
- for a in names:
- image_path = os.path.join(opts.inspiration_dir, a)
- images = os.listdir(image_path)
- if len(images) > 0:
- image_list.append((os.path.join(image_path, random.choice(images)), a))
- else:
- print(image_path)
- return image_list, names
-
-def select_click(index, name_list):
- name = name_list[int(index)]
- path = os.path.join(opts.inspiration_dir, name)
- images = os.listdir(path)
- return name, [os.path.join(path, x) for x in images], ""
-
-def give_up_click(name):
- file = os.path.join(inspiration_system_path, "abandoned.txt")
- save_name_list(file, name)
- return "Added to abandoned list"
-
-def collect_click(name):
- file = os.path.join(inspiration_system_path, "faverites.txt")
- save_name_list(file, name)
- return "Added to faverite list"
-
-def moveout_click(name, source):
- if source == "Abandoned":
- file = os.path.join(inspiration_system_path, "abandoned.txt")
- elif source == "Favorites":
- file = os.path.join(inspiration_system_path, "faverites.txt")
- else:
- return None
- name_list = read_name_list(file)
- os.remove(file)
- with open(file, "a") as f:
- for a in name_list:
- if a != name:
- f.write(a + "\n")
- return f"Moved out {name} from {source} list"
-
-def source_change(source):
- if source in ["Abandoned", "Favorites"]:
- return gradio.update(visible=True), []
- else:
- return gradio.update(visible=False), []
-def add_to_prompt(name, prompt):
- name = os.path.basename(name)
- return prompt + "," + name
-
-def clear_keyword():
- return ""
-
-def ui(gr, opts, txt2img_prompt, img2img_prompt):
- with gr.Blocks(analytics_enabled=False) as inspiration:
- flag = os.path.exists(opts.inspiration_dir)
- if flag:
- types = get_types_list()
- flag = len(types) > 0
- else:
- os.makedirs(opts.inspiration_dir)
- if not flag:
- gr.HTML("""
- To activate inspiration function, you need get "inspiration" images first.
- You can create these images by run "Create inspiration images" script in txt2img page,
you can get the artists or art styles list from here
-
https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
- download these files, and select these files in the "Create inspiration images" script UI
- There about 6000 artists and art styles in these files.
This takes server hours depending on your GPU type and how many pictures you generate for each artist/style
-
I suggest at least four images for each
-
You can also download generated pictures from here:
-
https://huggingface.co/datasets/yfszzx/inspiration
- unzip the file to the project directory of webui
- and restart webui, and enjoy the joy of creation!
- """)
- return inspiration
- if not os.path.exists(inspiration_system_path):
- os.mkdir(inspiration_system_path)
- with gr.Row():
- with gr.Column(scale=2):
- inspiration_gallery = gr.Gallery(show_label=False, elem_id="inspiration_gallery").style(grid=opts.inspiration_cols_num, height='auto')
- with gr.Column(scale=1):
- types = gr.CheckboxGroup(choices=types, value=types)
- with gr.Row():
- source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source")
- keyword = gr.Textbox("", label="Key word")
- get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button")
- name = gr.Textbox(show_label=False, interactive=False)
- with gr.Row():
- send_to_txt2img = gr.Button('to txt2img')
- send_to_img2img = gr.Button('to img2img')
- collect = gr.Button('Collect')
- give_up = gr.Button("Don't show again")
- moveout = gr.Button("Move out", visible=False)
- warning = gr.HTML()
- style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto')
-
-
-
- with gr.Row(visible=False):
- select_button = gr.Button('set button', elem_id="inspiration_select_button")
- name_list = gr.State()
-
- get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list])
- keyword.submit(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None)
- source.change(source_change, inputs=[source], outputs=[moveout, style_gallery])
- source.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword])
- types.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword])
-
- select_button.click(select_click, _js="inspiration_selected", inputs=[name, name_list], outputs=[name, style_gallery, warning])
- give_up.click(give_up_click, inputs=[name], outputs=[warning])
- collect.click(collect_click, inputs=[name], outputs=[warning])
- moveout.click(moveout_click, inputs=[name, source], outputs=[warning])
- moveout.click(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None)
-
- send_to_txt2img.click(add_to_prompt, inputs=[name, txt2img_prompt], outputs=[txt2img_prompt])
- send_to_img2img.click(add_to_prompt, inputs=[name, img2img_prompt], outputs=[img2img_prompt])
- send_to_txt2img.click(collect_click, inputs=[name], outputs=[warning])
- send_to_img2img.click(collect_click, inputs=[name], outputs=[warning])
- send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None)
- send_to_img2img.click(None, _js="switch_to_img2img_img2img", inputs=None, outputs=None)
- return inspiration
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 5bcccd67..66666a56 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,4 +1,3 @@
-
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
@@ -16,7 +15,6 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
-
for callback in callbacks_ui_tabs:
res += callback() or []
diff --git a/modules/shared.py b/modules/shared.py
index 0aaaadac..5dfd7927 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -321,21 +321,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
-options_templates.update(options_section(('inspiration', "Inspiration"), {
- "inspiration_dir": OptionInfo("inspiration", "Directory of inspiration", component_args=hide_dirs),
- "inspiration_max_samples": OptionInfo(4, "Maximum number of samples, used to determine which folders to skip when continue running the create script", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
- "inspiration_rows_num": OptionInfo(4, "Rows of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}),
- "inspiration_cols_num": OptionInfo(8, "Columns of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}),
-}))
-
-options_templates.update(options_section(('images-history', "Images Browser"), {
- #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"),
- "images_history_preload": OptionInfo(False, "Preload images at startup"),
- "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"),
- "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "),
- "images_history_grid_num": OptionInfo(6, "Number of grids in each row"),
-
-}))
class Options:
data = None
diff --git a/modules/ui.py b/modules/ui.py
index a73175f5..fa42712e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -49,14 +49,12 @@ from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
import modules.textual_inversion.ui
import modules.hypernetworks.ui
-import modules.images_history as images_history
-import modules.inspiration as inspiration
-
-
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
+txt2img_paste_fields = []
+img2img_paste_fields = []
if not cmd_opts.share and not cmd_opts.listen:
@@ -1193,16 +1191,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[image],
outputs=[html, generation_info, html2],
)
- #images history
- images_history_switch_dict = {
- "fn": modules.generation_parameters_copypaste.connect_paste,
- "t2i": txt2img_paste_fields,
- "i2i": img2img_paste_fields
- }
-
- browser_interface = images_history.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
- inspiration_interface = inspiration.ui(gr, opts, txt2img_prompt, img2img_prompt)
-
+
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -1651,8 +1640,6 @@ Requested path was: {f}
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
- (inspiration_interface, "Inspiration", "inspiration"),
- (browser_interface , "Image Browser", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
]
@@ -1896,6 +1883,7 @@ def load_javascript(raw_response):
javascript = f'<script>{jsfile.read()}</script>'
scripts_list = modules.scripts.list_scripts("javascript", ".js")
+ scripts_list += modules.scripts.list_scripts("scripts", ".js")
for basedir, filename, path in scripts_list:
with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n"
--
cgit v1.2.3
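The two deleted modules are intended to live on as extensions. A hedged sketch of how an extension would plug a replacement tab back in through the `callbacks_ui_tabs` mechanism kept in script_callbacks.py (`on_ui_tabs` is assumed to be the registration helper alongside `on_model_loaded`; each registered callback returns a list of tab tuples):

    import gradio as gr
    from modules import script_callbacks

    def add_browser_tab():
        # build the tab UI much like the removed built-in pages did
        with gr.Blocks(analytics_enabled=False) as image_browser:
            gr.HTML("image browser UI goes here")
        # each callback returns a list of (component, tab title, element id) tuples
        return [(image_browser, "Image Browser", "images_history")]

    script_callbacks.on_ui_tabs(add_browser_tab)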
From cef1b89aa2e6c7647db7e93a4cd4ec020da3f2da Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Mon, 24 Oct 2022 10:10:33 +0800
Subject: remove browser to extension
---
modules/script_callbacks.py | 2 ++
modules/shared.py | 1 -
modules/ui.py | 2 +-
3 files changed, 3 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 66666a56..f46d3d9a 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,3 +1,4 @@
+
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
@@ -15,6 +16,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
+
for callback in callbacks_ui_tabs:
res += callback() or []
diff --git a/modules/shared.py b/modules/shared.py
index 5dfd7927..6541e679 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -82,7 +82,6 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
-parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False)
cmd_opts = parser.parse_args()
restricted_opts = [
diff --git a/modules/ui.py b/modules/ui.py
index fa42712e..a32f7259 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1104,7 +1104,7 @@ def create_ui(wrap_gradio_gpu_call):
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
with gr.Group():
- extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers] , value=shared.sd_upscalers[0].name, type="index")
+ extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
--
cgit v1.2.3
From a889c93f23f1e80d0dac4e5ddbc3a26207e8cdf1 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Mon, 24 Oct 2022 11:13:16 +0800
Subject: paste_fields add to public
---
modules/ui.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index a32f7259..a73b9ff0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -784,6 +784,7 @@ def create_ui(wrap_gradio_gpu_call):
]
)
+ global txt2img_paste_fields
txt2img_paste_fields = [
(txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"),
@@ -1054,6 +1055,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[prompt, negative_prompt, style1, style2],
)
+ global img2img_paste_fields
img2img_paste_fields = [
(img2img_prompt, "Prompt"),
(img2img_negative_prompt, "Negative prompt"),
--
cgit v1.2.3
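The `global` statements are what make the previous commit work: `create_ui` rebinds the module-level names rather than mutating the lists, and without `global` the assignment would only create function locals, leaving the exported lists empty for any code importing them. A minimal illustration of the pitfall:

    txt2img_paste_fields = []

    def create_ui_without_global():
        txt2img_paste_fields = [("prompt", "Prompt")]  # binds a new local; module name untouched

    def create_ui_with_global():
        global txt2img_paste_fields
        txt2img_paste_fields = [("prompt", "Prompt")]  # rebinds the module-level name

    create_ui_without_global()
    print(txt2img_paste_fields)  # []
    create_ui_with_global()
    print(txt2img_paste_fields)  # [('prompt', 'Prompt')]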
From 974196932583b96b6b76632052fc0d7e70820bf3 Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Sun, 23 Oct 2022 22:38:42 +0300
Subject: Save properly processed image before color correction
---
modules/processing.py | 33 ++++++++++++++++++---------------
1 file changed, 18 insertions(+), 15 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index ff83023c..15b639e1 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -46,6 +46,20 @@ def apply_color_correction(correction, image):
return image
+def apply_overlay(overlay_exists, overlay, paste_loc, image):
+ if overlay_exists:
+ if paste_loc is not None:
+ x, y, w, h = paste_loc
+ base_image = Image.new('RGBA', (overlay.width, overlay.height))
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ image = image.convert('RGBA')
+ image.alpha_composite(overlay)
+ image = image.convert('RGB')
+
+ return image
def get_correct_sampler(p):
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
@@ -446,25 +460,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
image = Image.fromarray(x_sample)
-
+
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+ image_without_cc = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image)
+ images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
- if p.overlay_images is not None and i < len(p.overlay_images):
- overlay = p.overlay_images[i]
-
- if p.paste_to is not None:
- x, y, w, h = p.paste_to
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
-
- image = image.convert('RGBA')
- image.alpha_composite(overlay)
- image = image.convert('RGB')
+ image = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image)
if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
--
cgit v1.2.3
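The refactor pulls the inpainting overlay compositing into a helper so the "before color correction" copy can get the same treatment as the final image. A standalone Pillow sketch of the compositing steps themselves (sizes and paste location are invented for illustration):

    from PIL import Image

    def composite_overlay(image, paste_loc, overlay):
        if paste_loc is not None:
            # paste the (resized) sample into a transparent canvas the size of the overlay
            x, y, w, h = paste_loc
            base = Image.new('RGBA', (overlay.width, overlay.height))
            base.paste(image.resize((w, h)), (x, y))
            image = base
        # draw the overlay on top, then drop the alpha channel again
        image = image.convert('RGBA')
        image.alpha_composite(overlay)
        return image.convert('RGB')

    sample = Image.new('RGB', (64, 64), 'white')
    overlay = Image.new('RGBA', (128, 128), (255, 0, 0, 96))
    print(composite_overlay(sample, (32, 32, 64, 64), overlay).size)  # (128, 128)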
From f2cc3f32d5bc8538e95edec54d7dc1b9efdf769a Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Sun, 23 Oct 2022 22:44:46 +0300
Subject: fix whitespaces
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 15b639e1..2a332514 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -460,7 +460,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
image = Image.fromarray(x_sample)
-
+
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image)
--
cgit v1.2.3
From b297cc3324979ec78d69b2d11dd18030dfad7bcc Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sun, 23 Oct 2022 20:06:42 +0900
Subject: Hypernetworks - fix KeyError in statistics caching
Statistics logging has changed to {filename : list[losses]}, so it has to use loss_info[key].pop()
---
modules/hypernetworks/hypernetwork.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 98a7b62e..33827210 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -274,8 +274,8 @@ def log_statistics(loss_info:dict, key, value):
loss_info[key] = [value]
else:
loss_info[key].append(value)
- if len(loss_info) > 1024:
- loss_info.pop(0)
+ if len(loss_info[key]) > 1024:
+ loss_info[key].pop(0)
def statistics(data):
--
cgit v1.2.3
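The root cause: `loss_info` maps filenames to lists of losses, so `loss_info.pop(0)` looks up a dict key `0` that never exists, hence the KeyError; the cap has to be applied to the per-file list. A minimal sketch of the corrected pattern:

    def log_statistics(loss_info: dict, key, value, cap=1024):
        loss_info.setdefault(key, []).append(value)
        if len(loss_info[key]) > cap:
            loss_info[key].pop(0)  # drop the oldest loss for this file only

    losses = {}
    for _ in range(2000):
        log_statistics(losses, "img_0001.png", 0.1)
    print(len(losses["img_0001.png"]))  # 1024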
From 40b56c9289bf9458ae5ef3c1990ccea851c6c3e2 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sun, 23 Oct 2022 21:07:07 +0900
Subject: cleanup some code
---
modules/hypernetworks/hypernetwork.py | 14 +++-----------
1 file changed, 3 insertions(+), 11 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 33827210..4072bf54 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
+from collections import defaultdict, deque
from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
@@ -269,15 +270,6 @@ def stack_conds(conds):
return torch.stack(conds)
-def log_statistics(loss_info:dict, key, value):
- if key not in loss_info:
- loss_info[key] = [value]
- else:
- loss_info[key].append(value)
- if len(loss_info[key]) > 1024:
- loss_info[key].pop(0)
-
-
def statistics(data):
total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})"
recent_data = data[-32:]
@@ -341,7 +333,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
weight.requires_grad = True
size = len(ds.indexes)
- loss_dict = {}
+ loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
@@ -383,7 +375,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item()
for entry in entries:
- log_statistics(loss_dict, entry.filename, loss.item())
+ loss_dict[entry.filename].append(loss.item())
optimizer.zero_grad()
weights[0].grad = None
--
cgit v1.2.3
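The cleanup moves the bookkeeping into the data structure: `defaultdict` creates a per-file history on first access and `deque(maxlen=1024)` silently evicts the oldest entry, so the whole `log_statistics` helper disappears. Behavior in isolation:

    from collections import defaultdict, deque

    loss_dict = defaultdict(lambda: deque(maxlen=1024))

    for _ in range(2000):
        loss_dict["img_0001.png"].append(0.1)  # no KeyError, no manual pop

    print(len(loss_dict["img_0001.png"]))  # 1024 -- older values were evicted automatically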
From 348f89c8d40397c1875cff4a7331018785f9c3b8 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sun, 23 Oct 2022 21:29:53 +0900
Subject: statistics for pbar
---
modules/hypernetworks/hypernetwork.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 4072bf54..48b56029 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -335,6 +335,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
+ previous_mean_losses = [0]
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
@@ -356,7 +357,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
for i, entries in pbar:
hypernetwork.step = i + ititial_step
if len(loss_dict) > 0:
- previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
+ previous_mean_losses = [i[-1] for i in loss_dict.values()]
+ previous_mean_loss = mean(previous_mean_losses)
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
@@ -391,7 +393,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
- pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")
+
+ if len(previous_mean_losses) > 1:
+ std = stdev(previous_mean_losses)
+ else:
+ std = 0
+ dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
+ pbar.set_description(dataset_loss_info)
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
--
cgit v1.2.3
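The progress bar now shows the mean of the latest loss per file, with the standard error of that mean (std / sqrt(n)) after the ± sign, guarded for the single-file case. The string construction boils down to:

    from statistics import mean, stdev

    def dataset_loss_text(latest_losses):
        std = stdev(latest_losses) if len(latest_losses) > 1 else 0
        sem = std / (len(latest_losses) ** 0.5)  # standard error of the mean
        return f"dataset loss:{mean(latest_losses):.3f}" + "\u00B1" + f"({sem:.3f})"

    print(dataset_loss_text([0.11, 0.13, 0.09]))  # dataset loss:0.110±(0.012)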
From 0d2e1dac407a0e2f5b148d314715f0457b2525b7 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sun, 23 Oct 2022 21:41:39 +0900
Subject: convert deque -> list
I don't think this is efficient
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 48b56029..fb510fa7 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -282,7 +282,7 @@ def report_statistics(loss_info:dict):
for key in keys:
try:
print("Loss statistics for file " + key)
- info, recent = statistics(loss_info[key])
+ info, recent = statistics(list(loss_info[key]))
print(info)
print(recent)
except Exception as e:
--
cgit v1.2.3
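The `list(...)` conversion is needed because `statistics()` slices its argument (`data[-32:]`), and `collections.deque` supports indexing but not slicing:

    from collections import deque

    history = deque([0.1, 0.2, 0.3], maxlen=1024)
    try:
        recent = history[-32:]  # raises TypeError: sequence index must be integer
    except TypeError:
        recent = list(history)[-32:]  # converting to a list first makes slicing work
    print(recent)  # [0.1, 0.2, 0.3]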
From e9a410b5357612f63528015c5533c2185dcff92e Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sun, 23 Oct 2022 21:47:39 +0900
Subject: check length for variance
---
modules/hypernetworks/hypernetwork.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index fb510fa7..d647ea55 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -271,9 +271,17 @@ def stack_conds(conds):
def statistics(data):
- total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})"
+ if len(data) < 2:
+ std = 0
+ else:
+ std = stdev(data)
+ total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
recent_data = data[-32:]
- recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})"
+ if len(recent_data) < 2:
+ std = 0
+ else:
+ std = stdev(recent_data)
+ recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
return total_information, recent_information
--
cgit v1.2.3
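`statistics.stdev` raises `StatisticsError` when given fewer than two data points, which is exactly what happens on the first step or with a single recent sample; the guard substitutes 0 instead:

    from statistics import mean, stdev

    def safe_loss_summary(data):
        std = stdev(data) if len(data) >= 2 else 0  # stdev needs at least two points
        return f"loss:{mean(data):.3f}" + "\u00B1" + f"({std / (len(data) ** 0.5):.3f})"

    print(safe_loss_summary([0.5]))       # loss:0.500±(0.000) instead of StatisticsError
    print(safe_loss_summary([0.4, 0.6]))  # loss:0.500±(0.100)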
From 6cbb04f7a5e675cf1f6dfc247aa9c9e8df7dc5ce Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 24 Oct 2022 09:15:26 +0300
Subject: fix #3517 breaking txt2img
---
modules/processing.py | 33 +++++++++++++++++++--------------
1 file changed, 19 insertions(+), 14 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 2a332514..c61bbfbd 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -46,18 +46,23 @@ def apply_color_correction(correction, image):
return image
-def apply_overlay(overlay_exists, overlay, paste_loc, image):
- if overlay_exists:
- if paste_loc is not None:
- x, y, w, h = paste_loc
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
-
- image = image.convert('RGBA')
- image.alpha_composite(overlay)
- image = image.convert('RGB')
+
+def apply_overlay(image, paste_loc, index, overlays):
+ if overlays is None or index >= len(overlays):
+ return image
+
+ overlay = overlays[index]
+
+ if paste_loc is not None:
+ x, y, w, h = paste_loc
+ base_image = Image.new('RGBA', (overlay.width, overlay.height))
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ image = image.convert('RGBA')
+ image.alpha_composite(overlay)
+ image = image.convert('RGB')
return image
@@ -463,11 +468,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
- image_without_cc = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image)
+ image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
- image = apply_overlay(p.overlay_images is not None and i < len(p.overlay_images), p.overlay_images[i], p.paste_to, image)
+ image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
--
cgit v1.2.3
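The corrected signature moves the existence check inside the helper, so txt2img (which has no overlays) gets its image back untouched instead of crashing on `p.overlay_images[i]`. A condensed sketch of the guard, with the paste-location branch omitted:

    from PIL import Image

    def apply_overlay(image, paste_loc, index, overlays):
        # txt2img passes overlays=None; img2img may have fewer overlays than samples
        if overlays is None or index >= len(overlays):
            return image
        image = image.convert('RGBA')
        image.alpha_composite(overlays[index])
        return image.convert('RGB')

    img = Image.new('RGB', (64, 64), 'white')
    print(apply_overlay(img, None, 0, None) is img)  # True: no overlays, image untouched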
From 734986dde3231416813f827242c111da212b2ccb Mon Sep 17 00:00:00 2001
From: Trung Ngo
Date: Mon, 24 Oct 2022 01:17:09 -0500
Subject: add callback after image is saved
---
modules/images.py | 3 ++-
modules/script_callbacks.py | 12 +++++++++++-
2 files changed, 13 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index b9589563..01c60f89 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -12,7 +12,7 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string
-from modules import sd_samplers, shared
+from modules import sd_samplers, shared, script_callbacks
from modules.shared import opts, cmd_opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -467,6 +467,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else:
txt_fullfn = None
+ script_callbacks.image_saved_callback(image, p, fullfn, txt_fullfn)
return fullfn, txt_fullfn
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 5bcccd67..5836e4b9 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -2,11 +2,12 @@
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
-
+callbacks_image_saved = []
def clear_callbacks():
callbacks_model_loaded.clear()
callbacks_ui_tabs.clear()
+ callbacks_image_saved.clear()
def model_loaded_callback(sd_model):
@@ -28,6 +29,10 @@ def ui_settings_callback():
callback()
+def image_saved_callback(image, p, fullfn, txt_fullfn):
+ for callback in callbacks_image_saved:
+ callback(image, p, fullfn, txt_fullfn)
+
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
@@ -51,3 +56,8 @@ def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
callbacks_ui_settings.append(callback)
+
+
+def on_save_imaged(callback):
+ """register a function to call after modules.images.save_image is called returning same values, original image and p """
+ callbacks_image_saved.append(callback)
--
cgit v1.2.3
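With the callback list in place, any script can observe saved images. A minimal registration sketch using the `on_save_imaged` helper exactly as added above:

    from modules import script_callbacks

    def log_saved_image(image, p, fullfn, txt_fullfn):
        # invoked by modules.images.save_image right after the file is written
        print(f"saved {fullfn}, info text: {txt_fullfn}")

    script_callbacks.on_save_imaged(log_saved_image)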
From 876a96f0f9843382ebc8984db3de5d8af0e9ce4c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 24 Oct 2022 09:39:46 +0300
Subject: remove erroneous dir in the extension directory; remove loading .js
files from scripts dir (they go into javascript); load scripts after models,
for scripts that depend on loaded models
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index a73b9ff0..03528968 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1885,7 +1885,7 @@ def load_javascript(raw_response):
javascript = f'<script>{jsfile.read()}</script>'
scripts_list = modules.scripts.list_scripts("javascript", ".js")
- scripts_list += modules.scripts.list_scripts("scripts", ".js")
+
for basedir, filename, path in scripts_list:
with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n"
--
cgit v1.2.3
From 3be6b29d81408d2adb741bff5b11c80214aa621e Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Mon, 24 Oct 2022 15:14:34 +0900
Subject: indent=4 config.json
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 6541e679..d6ddfe59 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -348,7 +348,7 @@ class Options:
def save(self, filename):
with open(filename, "w", encoding="utf8") as file:
- json.dump(self.data, file)
+ json.dump(self.data, file, indent=4)
def same_type(self, x, y):
if x is None or y is None:
--
cgit v1.2.3
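A before/after comparison of what the one-argument change does to the saved settings file:

    import json

    data = {"samples_save": True, "samples_format": "png"}
    print(json.dumps(data))            # {"samples_save": true, "samples_format": "png"}
    print(json.dumps(data, indent=4))  # one key per line, indented 4 spaces -- a diff-friendly config.json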
From c5d90628a4058bf49c2fdabf620a24db73407f31 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 22 Oct 2022 17:16:55 +0900
Subject: move "file_decoration" initialize section
into "if forced_filename is None:"
no need to initialize it if it's not going to be used
---
modules/images.py | 24 ++++++++++++------------
1 file changed, 12 insertions(+), 12 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index b9589563..50a59cff 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -386,18 +386,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
txt_fullfn (`str` or None):
If a text file is saved for this image, this will be its full path. Otherwise None.
'''
- if short_filename or prompt is None or seed is None:
- file_decoration = ""
- elif opts.save_to_dirs:
- file_decoration = opts.samples_filename_pattern or "[seed]"
- else:
- file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
-
- if file_decoration != "":
- file_decoration = "-" + file_decoration.lower()
-
- file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
-
if extension == 'png' and opts.enable_pnginfo and info is not None:
pnginfo = PngImagePlugin.PngInfo()
@@ -419,6 +407,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
os.makedirs(path, exist_ok=True)
if forced_filename is None:
+ if short_filename or prompt is None or seed is None:
+ file_decoration = ""
+ elif opts.save_to_dirs:
+ file_decoration = opts.samples_filename_pattern or "[seed]"
+ else:
+ file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
+
+ if file_decoration != "":
+ file_decoration = "-" + file_decoration.lower()
+
+ file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
+
basecount = get_next_sequence_number(path, basename)
fullfn = "a.png"
fullfn_without_extension = "a"
--
cgit v1.2.3
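The move is a lazy-initialization cleanup: the decoration string is only built on the branch that consumes it. Schematically (the helper names below are placeholders, not the real ones):

    def pick_filename(forced_filename, build_decoration):
        if forced_filename is None:
            file_decoration = build_decoration()  # computed only when actually needed
            return f"00001{file_decoration}.png"
        return f"{forced_filename}.png"

    print(pick_filename(None, lambda: "-12345"))         # 00001-12345.png
    print(pick_filename("grid-0001", lambda: "-12345"))  # grid-0001.png, decoration never built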
From 7d4a4db9ea7543c079f4a4a702c2945f4b66cd11 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 22 Oct 2022 17:48:59 +0900
Subject: modify unnecessary string assignment as it's going to get overwritten
---
modules/images.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index 50a59cff..cc5066b1 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -420,8 +420,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
basecount = get_next_sequence_number(path, basename)
- fullfn = "a.png"
- fullfn_without_extension = "a"
+ fullfn = None
+ fullfn_without_extension = None
for i in range(500):
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
--
cgit v1.2.3
From 37dd6deafb831a809eaf7ae8d232937a8c7998e7 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 22 Oct 2022 21:11:15 +0900
Subject: filename pattern [datetime], extended customizable Format and Time
Zone
format:
[datetime]
[datetime<Format>]
[datetime<Format><Time Zone>]
---
modules/images.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 53 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index cc5066b1..fb0b373d 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,4 +1,5 @@
import datetime
+import pytz
import io
import math
import os
@@ -302,7 +303,7 @@ def apply_filename_pattern(x, p, seed, prompt):
x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash))
x = x.replace("[date]", datetime.date.today().isoformat())
- x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
+ x = replace_datetime(x)
x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp))
# Apply [prompt] at last. Because it may contain any replacement word.
@@ -353,6 +354,57 @@ def get_next_sequence_number(path, basename):
return result + 1
+def replace_datetime(input_str: str):
+ """
+ Formats sub_string of input_str with formatted datetime with time zone support.
+ accepts sub_string format: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
+ case insensitive
+
+ e.g.
+ input: "___[Datetime<%Y_%m_%d %H-%M-%S>]___"
+ return: "___2022_10_22 20-40-14___"
+
+ handles invalid Formats and Time Zones
+
+ time format reference:
+ https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
+
+ valid time zones
+ print(pytz.all_timezones)
+ https://pytz.sourceforge.net/
+ """
+ default_time_format = '%Y%m%d%H%M%S'
+ time_now = datetime.datetime.now()
+
+ # match all datetime to be replace
+ match_itr = re.finditer(r'\[datetime(?:<([^>]*)>(?:<([^>]*)>)?)?]', input_str, re.IGNORECASE)
+ for match in reversed(list(match_itr)):
+ # extract format
+ time_format = match.group(1)
+ if time_format == '':
+ # if time_format is blank use default YYYYMMDDHHMMSS
+ time_format = default_time_format
+
+ # extract timezone
+ try:
+ time_zone = pytz.timezone(match.group(2))
+ except pytz.exceptions.UnknownTimeZoneError as _:
+ # if no time_zone or invalid, use system time
+ time_zone = None
+
+ # generate time string
+ time_zone_time = time_now.astimezone(time_zone)
+ try:
+ formatted_time = time_zone_time.strftime(time_format)
+
+ except (ValueError, TypeError) as _:
+ # if format error then use default_time_format
+ formatted_time = time_zone_time.strftime(default_time_format)
+
+ input_str = input_str[:match.start()] + formatted_time + input_str[match.end():]
+ return input_str
+
+
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
'''Save an image.
--
cgit v1.2.3
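A trimmed, standalone version of the replacement loop described in the docstring above, showing how the optional strftime format and pytz time zone degrade gracefully (the regex is the one from the diff; the error handling mirrors the documented fallbacks):

    import datetime
    import re

    import pytz

    def replace_datetime_sketch(s, now=None):
        now = now or datetime.datetime.now()
        matches = re.finditer(r'\[datetime(?:<([^>]*)>(?:<([^>]*)>)?)?]', s, re.IGNORECASE)
        for match in reversed(list(matches)):  # replace right-to-left so match indices stay valid
            fmt = match.group(1) or '%Y%m%d%H%M%S'  # missing or blank format: use the default
            tzname = match.group(2)
            try:
                tz = pytz.timezone(tzname) if tzname else None
            except pytz.exceptions.UnknownTimeZoneError:
                tz = None  # invalid time zone: fall back to system local time
            try:
                formatted = now.astimezone(tz).strftime(fmt)
            except (ValueError, TypeError):
                formatted = now.astimezone(tz).strftime('%Y%m%d%H%M%S')
            s = s[:match.start()] + formatted + s[match.end():]
        return s

    print(replace_datetime_sketch("[datetime<%Y_%m_%d %H-%M-%S><Asia/Tokyo>].png"))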
From 480d8e764611d7779bf8e0245640a6b84af626f9 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 22 Oct 2022 22:26:57 +0900
Subject: replace "srt.replace()" in apply_filename_pattern() with equivalent
re.sub()
the file_decoration passed into apply_filename_pattern() is formatted to lowercase to increase compatibility
the use of case sensitive srt.replace()
but because the newly implemented "time format" is case sensitive
the lowercasing the file_decoration will cause time format to be broken
in order to resolve this issue
I decided to replace every srt.replace() and in if "str" in x to regular expression (case insensitive) equivalent
---
modules/images.py | 35 +++++++++++++++++------------------
1 file changed, 17 insertions(+), 18 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index fb0b373d..55d8ab5e 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -291,25 +291,24 @@ def apply_filename_pattern(x, p, seed, prompt):
max_prompt_words = opts.directories_max_prompt_words
if seed is not None:
- x = x.replace("[seed]", str(seed))
+ x = re.sub(r'\[seed]', str(seed), x, flags=re.IGNORECASE)
if p is not None:
- x = x.replace("[steps]", str(p.steps))
- x = x.replace("[cfg]", str(p.cfg_scale))
- x = x.replace("[width]", str(p.width))
- x = x.replace("[height]", str(p.height))
- x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
- x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
-
- x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash))
- x = x.replace("[date]", datetime.date.today().isoformat())
+ x = re.sub(r'\[steps]', str(p.steps), x, flags=re.IGNORECASE)
+ x = re.sub(r'\[cfg]', str(p.cfg_scale), x, flags=re.IGNORECASE)
+ x = re.sub(r'\[width]', str(p.width), x, flags=re.IGNORECASE)
+ x = re.sub(r'\[height]', str(p.height), x, flags=re.IGNORECASE)
+ x = re.sub(r'\[styles]', sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False), x, flags=re.IGNORECASE)
+ x = re.sub(r'\[sampler]', sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False), x, flags=re.IGNORECASE)
+
+ x = re.sub(r'\[model_hash]', getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash), x, flags=re.IGNORECASE)
+ x = re.sub(r'\[date]', datetime.date.today().isoformat(), x, flags=re.IGNORECASE)
x = replace_datetime(x)
- x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp))
-
+ x = re.sub(r'\[job_timestamp]', getattr(p, "job_timestamp", shared.state.job_timestamp), x, flags=re.IGNORECASE)
# Apply [prompt] at last. Because it may contain any replacement word.
if prompt is not None:
- x = x.replace("[prompt]", sanitize_filename_part(prompt))
- if "[prompt_no_styles]" in x:
+ x = re.sub(r'\[prompt]', sanitize_filename_part(prompt), x, flags=re.IGNORECASE)
+ if re.search(r'\[prompt_no_styles]', x, re.IGNORECASE):
prompt_no_style = prompt
for style in shared.prompt_styles.get_style_prompts(p.styles):
if len(style) > 0:
@@ -317,14 +316,14 @@ def apply_filename_pattern(x, p, seed, prompt):
for part in style_parts:
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
- x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
+ x = re.sub(r'\[prompt_no_styles]', sanitize_filename_part(prompt_no_style, replace_spaces=False), x, flags=re.IGNORECASE)
- x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
- if "[prompt_words]" in x:
+ x = re.sub(r'\[prompt_spaces]', sanitize_filename_part(prompt, replace_spaces=False), x, flags=re.IGNORECASE)
+ if re.search(r'\[prompt_words]', x, re.IGNORECASE):
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
if len(words) == 0:
words = ["empty"]
- x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
+ x = re.sub(r'\[prompt_words]', sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False), x, flags=re.IGNORECASE)
if cmd_opts.hide_ui_dir_config:
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
--
cgit v1.2.3
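
The difference the commit above relies on, as a standalone sketch (standard library only; the token values are made up):

    import re

    pattern = "[Seed]-[PROMPT_SPACES]"

    # str.replace is case sensitive, so the differently-cased token survives:
    print(pattern.replace("[seed]", "1234"))  # [Seed]-[PROMPT_SPACES]

    # re.sub with re.IGNORECASE matches any casing of the token:
    print(re.sub(r'\[seed]', '1234', pattern, flags=re.IGNORECASE))  # 1234-[PROMPT_SPACES]

    # Caveat: the replacement string is a template, so backslashes in dynamic
    # values (e.g. a prompt) can be misread as group references; passing a
    # function as the replacement sidesteps that.
    prompt = r'a\1b'
    print(re.sub(r'\[prompt_spaces]', lambda m: prompt, pattern, flags=re.IGNORECASE))  # [Seed]-a\1b
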
From 00952fb4a8b86b75402cc8276e00f3daadf1f459 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 22 Oct 2022 22:31:02 +0900
Subject: add sanitize_filename() to datetime
---
modules/images.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index 55d8ab5e..898b065b 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -400,6 +400,7 @@ def replace_datetime(input_str: str):
# if format error then use default_time_format
formatted_time = time_zone_time.strftime(default_time_format)
+ formatted_time = sanitize_filename_part(formatted_time, replace_spaces=False)
input_str = input_str[:match.start()] + formatted_time + input_str[match.end():]
return input_str
--
cgit v1.2.3
From 8f6af4ed651cdb4456e2f3474d05dd3a18086ac2 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 22 Oct 2022 22:32:26 +0900
Subject: remove lowercasing file_decoration as it is not needed anymore
---
modules/images.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index 898b065b..321439e3 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -467,7 +467,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
if file_decoration != "":
- file_decoration = "-" + file_decoration.lower()
+ file_decoration = "-" + file_decoration
file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
--
cgit v1.2.3
From 5a981310e68253c77f9fe3144d247cfd531f9e48 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 23 Oct 2022 18:09:21 +0900
Subject: replace_datetime() can now accept a datetime parameter
---
modules/images.py | 16 +++++++++++-----
1 file changed, 11 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index 321439e3..ff00fc52 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -303,7 +303,7 @@ def apply_filename_pattern(x, p, seed, prompt):
x = re.sub(r'\[model_hash]', getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash), x, flags=re.IGNORECASE)
x = re.sub(r'\[date]', datetime.date.today().isoformat(), x, flags=re.IGNORECASE)
- x = replace_datetime(x)
+ x = replace_datetime(x, datetime.datetime.now())
x = re.sub(r'\[job_timestamp]', getattr(p, "job_timestamp", shared.state.job_timestamp), x, flags=re.IGNORECASE)
# Apply [prompt] at last. Because it may contain any replacement word.
if prompt is not None:
@@ -353,8 +353,14 @@ def get_next_sequence_number(path, basename):
return result + 1
-def replace_datetime(input_str: str):
+def replace_datetime(input_str: str, time_datetime: datetime.datetime = None):
"""
+ Args:
+ input_str (`str`):
+ the String to be Formatted
+ time_datetime (`datetime.datetime`)
+ the time to be used, if None, use datetime.datetime.now()
+
Formats sub_string of input_str with formatted datetime with time zone support.
accepts sub_string format: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
case insensitive
@@ -373,8 +379,8 @@ def replace_datetime(input_str: str):
https://pytz.sourceforge.net/
"""
default_time_format = '%Y%m%d%H%M%S'
- time_now = datetime.datetime.now()
-
+ if time_datetime is None:
+ time_datetime = datetime.datetime.now()
# match all datetime to be replace
match_itr = re.finditer(r'\[datetime(?:<([^>]*)>(?:<([^>]*)>)?)?]', input_str, re.IGNORECASE)
for match in reversed(list(match_itr)):
@@ -392,7 +398,7 @@ def replace_datetime(input_str: str):
time_zone = None
# generate time string
- time_zone_time = time_now.astimezone(time_zone)
+ time_zone_time = time_datetime.astimezone(time_zone)
try:
formatted_time = time_zone_time.strftime(time_format)
--
cgit v1.2.3
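
Condensed sketch of the reworked function (single <Format> argument only, no time zone; replace_datetime_demo is an illustrative stand-in): iterating the matches in reverse keeps the offsets of unprocessed matches valid while the string is spliced, and the optional parameter lets callers supply one shared timestamp.

    import re
    import datetime

    def replace_datetime_demo(s: str, when: datetime.datetime = None) -> str:
        if when is None:
            when = datetime.datetime.now()
        # reversed(): splicing a replacement changes the string length,
        # so earlier matches are only touched after later ones.
        for m in reversed(list(re.finditer(r'\[datetime(?:<([^>]*)>)?]', s, re.IGNORECASE))):
            fmt = m.group(1) or '%Y%m%d%H%M%S'
            s = s[:m.start()] + when.strftime(fmt) + s[m.end():]
        return s

    stamp = datetime.datetime(2022, 10, 23, 18, 9, 21)
    print(replace_datetime_demo("a-[datetime]-b-[datetime<%H-%M>]", stamp))
    # a-20221023180921-b-18-09
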
From eb007e5884c23fbc38d7e9d1dd3669625270ca27 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 23 Oct 2022 18:17:39 +0900
Subject: use the same datetime object for [date] and [datetime]
---
modules/images.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index ff00fc52..a9b1330d 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -302,8 +302,9 @@ def apply_filename_pattern(x, p, seed, prompt):
x = re.sub(r'\[sampler]', sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False), x, flags=re.IGNORECASE)
x = re.sub(r'\[model_hash]', getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash), x, flags=re.IGNORECASE)
- x = re.sub(r'\[date]', datetime.date.today().isoformat(), x, flags=re.IGNORECASE)
- x = replace_datetime(x, datetime.datetime.now())
+ current_time = datetime.datetime.now()
+ x = re.sub(r'\[date]', current_time.strftime('%Y-%m-%d'), x, flags=re.IGNORECASE)
+ x = replace_datetime(x, current_time) # replace [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
x = re.sub(r'\[job_timestamp]', getattr(p, "job_timestamp", shared.state.job_timestamp), x, flags=re.IGNORECASE)
# Apply [prompt] at last. Because it may contain any replacement word.
if prompt is not None:
--
cgit v1.2.3
From 994aaadf0861366b9e6f219e1a3c25a233fbb63c Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Mon, 24 Oct 2022 16:44:36 +0800
Subject: a strange bug
---
modules/ui.py | 4 ++++
1 file changed, 4 insertions(+)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index a73b9ff0..8c6dc026 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -55,6 +55,7 @@ mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
txt2img_paste_fields = []
img2img_paste_fields = []
+init_img_components = {}
if not cmd_opts.share and not cmd_opts.listen:
@@ -1174,6 +1175,9 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[init_img_with_mask],
)
+ global init_img_components
+ init_img_components = {"img2img":init_img, "inpaint":init_img_with_mask, "extras":extras_image}
+
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
--
cgit v1.2.3
From 8da1bd48bf9d0411cd9ba87b8d9220743cb5807e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 24 Oct 2022 14:03:58 +0300
Subject: add an option to skip adding number to filenames when saving. rework
filename pattern function: go through the pattern once and do not calculate any
of the replacements until they are actually encountered in the pattern.
---
modules/images.py | 250 ++++++++++++++++++++++++++++--------------------------
modules/shared.py | 8 +-
2 files changed, 135 insertions(+), 123 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index a9b1330d..848ede75 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,4 +1,7 @@
import datetime
+import sys
+import traceback
+
import pytz
import io
import math
@@ -274,10 +277,15 @@ invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
+re_pattern = re.compile(r"([^\[\]]+|\[([^]]+)]|[\[\]]*)")
+re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128
def sanitize_filename_part(text, replace_spaces=True):
+ if text is None:
+ return None
+
if replace_spaces:
text = text.replace(' ', '_')
@@ -287,49 +295,103 @@ def sanitize_filename_part(text, replace_spaces=True):
return text
-def apply_filename_pattern(x, p, seed, prompt):
- max_prompt_words = opts.directories_max_prompt_words
-
- if seed is not None:
- x = re.sub(r'\[seed]', str(seed), x, flags=re.IGNORECASE)
-
- if p is not None:
- x = re.sub(r'\[steps]', str(p.steps), x, flags=re.IGNORECASE)
- x = re.sub(r'\[cfg]', str(p.cfg_scale), x, flags=re.IGNORECASE)
- x = re.sub(r'\[width]', str(p.width), x, flags=re.IGNORECASE)
- x = re.sub(r'\[height]', str(p.height), x, flags=re.IGNORECASE)
- x = re.sub(r'\[styles]', sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False), x, flags=re.IGNORECASE)
- x = re.sub(r'\[sampler]', sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False), x, flags=re.IGNORECASE)
-
- x = re.sub(r'\[model_hash]', getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash), x, flags=re.IGNORECASE)
- current_time = datetime.datetime.now()
- x = re.sub(r'\[date]', current_time.strftime('%Y-%m-%d'), x, flags=re.IGNORECASE)
- x = replace_datetime(x, current_time) # replace [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
- x = re.sub(r'\[job_timestamp]', getattr(p, "job_timestamp", shared.state.job_timestamp), x, flags=re.IGNORECASE)
- # Apply [prompt] at last. Because it may contain any replacement word.
- if prompt is not None:
- x = re.sub(r'\[prompt]', sanitize_filename_part(prompt), x, flags=re.IGNORECASE)
- if re.search(r'\[prompt_no_styles]', x, re.IGNORECASE):
- prompt_no_style = prompt
- for style in shared.prompt_styles.get_style_prompts(p.styles):
- if len(style) > 0:
- style_parts = [y for y in style.split("{prompt}")]
- for part in style_parts:
- prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
- prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
- x = re.sub(r'\[prompt_no_styles]', sanitize_filename_part(prompt_no_style, replace_spaces=False), x, flags=re.IGNORECASE)
-
- x = re.sub(r'\[prompt_spaces]', sanitize_filename_part(prompt, replace_spaces=False), x, flags=re.IGNORECASE)
- if re.search(r'\[prompt_words]', x, re.IGNORECASE):
- words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
- if len(words) == 0:
- words = ["empty"]
- x = re.sub(r'\[prompt_words]', sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False), x, flags=re.IGNORECASE)
-
- if cmd_opts.hide_ui_dir_config:
- x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
-
- return x
+class FilenameGenerator:
+ replacements = {
+ 'seed': lambda self: self.seed if self.seed is not None else '',
+ 'steps': lambda self: self.p and self.p.steps,
+ 'cfg': lambda self: self.p and self.p.cfg_scale,
+ 'width': lambda self: self.p and self.p.width,
+ 'height': lambda self: self.p and self.p.height,
+ 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
+ 'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
+ 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
+ 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
+ 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
+ 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
+ 'prompt': lambda self: sanitize_filename_part(self.prompt),
+ 'prompt_no_styles': lambda self: self.prompt_no_style(),
+ 'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
+ 'prompt_words': lambda self: self.prompt_words(),
+ }
+ default_time_format = '%Y%m%d%H%M%S'
+
+ def __init__(self, p, seed, prompt):
+ self.p = p
+ self.seed = seed
+ self.prompt = prompt
+
+ def prompt_no_style(self):
+ if self.p is None or self.prompt is None:
+ return None
+
+ prompt_no_style = self.prompt
+ for style in shared.prompt_styles.get_style_prompts(self.p.styles):
+ if len(style) > 0:
+ for part in style.split("{prompt}"):
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
+
+ prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
+
+ return sanitize_filename_part(prompt_no_style, replace_spaces=False)
+
+ def prompt_words(self):
+ words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
+ if len(words) == 0:
+ words = ["empty"]
+ return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
+
+ def datetime(self, *args):
+ time_datetime = datetime.datetime.now()
+
+ time_format = args[0] if len(args) > 0 else self.default_time_format
+ time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
+
+ time_zone_time = time_datetime.astimezone(time_zone)
+ try:
+ formatted_time = time_zone_time.strftime(time_format)
+ except (ValueError, TypeError) as _:
+ formatted_time = time_zone_time.strftime(self.default_time_format)
+
+ return sanitize_filename_part(formatted_time, replace_spaces=False)
+
+ def apply(self, x):
+ res = ''
+
+ for m in re_pattern.finditer(x):
+ text, pattern = m.groups()
+
+ if pattern is None:
+ res += text
+ continue
+
+ pattern_args = []
+ while True:
+ m = re_pattern_arg.match(pattern)
+ if m is None:
+ break
+
+ pattern, arg = m.groups()
+ pattern_args.insert(0, arg)
+
+ fun = self.replacements.get(pattern.lower())
+ if fun is not None:
+ try:
+ replacement = fun(self, *pattern_args)
+ except Exception:
+ replacement = None
+ print(f"Error adding [{pattern}] to filename", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ if replacement is None:
+ res += f'[{pattern}]'
+ else:
+ res += str(replacement)
+
+ continue
+
+ res += f'[{pattern}]'
+
+ return res
def get_next_sequence_number(path, basename):
@@ -354,66 +416,8 @@ def get_next_sequence_number(path, basename):
return result + 1
-def replace_datetime(input_str: str, time_datetime: datetime.datetime = None):
- """
- Args:
- input_str (`str`):
- the String to be Formatted
- time_datetime (`datetime.datetime`)
- the time to be used, if None, use datetime.datetime.now()
-
- Formats sub_string of input_str with formatted datetime with time zone support.
- accepts sub_string format: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
- case insensitive
-
- e.g.
- input: "___[Datetime<%Y_%m_%d %H-%M-%S>]___"
- return: "___2022_10_22 20-40-14___"
-
- handles invalid Formats and Time Zones
-
- time format reference:
- https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
-
- valid time zones
- print(pytz.all_timezones)
- https://pytz.sourceforge.net/
- """
- default_time_format = '%Y%m%d%H%M%S'
- if time_datetime is None:
- time_datetime = datetime.datetime.now()
- # match all datetime to be replace
- match_itr = re.finditer(r'\[datetime(?:<([^>]*)>(?:<([^>]*)>)?)?]', input_str, re.IGNORECASE)
- for match in reversed(list(match_itr)):
- # extract format
- time_format = match.group(1)
- if time_format == '':
- # if time_format is blank use default YYYYMMDDHHMMSS
- time_format = default_time_format
-
- # extract timezone
- try:
- time_zone = pytz.timezone(match.group(2))
- except pytz.exceptions.UnknownTimeZoneError as _:
- # if no time_zone or invalid, use system time
- time_zone = None
-
- # generate time string
- time_zone_time = time_datetime.astimezone(time_zone)
- try:
- formatted_time = time_zone_time.strftime(time_format)
-
- except (ValueError, TypeError) as _:
- # if format error then use default_time_format
- formatted_time = time_zone_time.strftime(default_time_format)
-
- formatted_time = sanitize_filename_part(formatted_time, replace_spaces=False)
- input_str = input_str[:match.start()] + formatted_time + input_str[match.end():]
- return input_str
-
-
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
- '''Save an image.
+ """Save an image.
Args:
image (`PIL.Image`):
@@ -444,7 +448,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
The full path of the saved imaged.
txt_fullfn (`str` or None):
If a text file is saved for this image, this will be its full path. Otherwise None.
- '''
+ """
+ namegen = FilenameGenerator(p, seed, prompt)
+
if extension == 'png' and opts.enable_pnginfo and info is not None:
pnginfo = PngImagePlugin.PngInfo()
@@ -460,33 +466,37 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
if save_to_dirs:
- dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
+ dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
path = os.path.join(path, dirname)
os.makedirs(path, exist_ok=True)
if forced_filename is None:
- if short_filename or prompt is None or seed is None:
+ if short_filename or seed is None:
file_decoration = ""
- elif opts.save_to_dirs:
- file_decoration = opts.samples_filename_pattern or "[seed]"
else:
- file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
+ file_decoration = opts.samples_filename_pattern or "[seed]"
+
+ add_number = opts.save_images_add_number or file_decoration == ''
- if file_decoration != "":
+ if file_decoration != "" and add_number:
file_decoration = "-" + file_decoration
- file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
-
- basecount = get_next_sequence_number(path, basename)
- fullfn = None
- fullfn_without_extension = None
- for i in range(500):
- fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
- fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
- fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
- if not os.path.exists(fullfn):
- break
+ file_decoration = namegen.apply(file_decoration) + suffix
+
+ if add_number:
+ basecount = get_next_sequence_number(path, basename)
+ fullfn = None
+ fullfn_without_extension = None
+ for i in range(500):
+ fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
+ fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
+ fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
+ if not os.path.exists(fullfn):
+ break
+ else:
+ fullfn = os.path.join(path, f"{file_decoration}.{extension}")
+ fullfn_without_extension = os.path.join(path, file_decoration)
else:
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
fullfn_without_extension = os.path.join(path, forced_filename)
diff --git a/modules/shared.py b/modules/shared.py
index d6ddfe59..76cbb1bd 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -86,6 +86,7 @@ parser.add_argument("--device-id", type=str, help="Select the default CUDA devic
cmd_opts = parser.parse_args()
restricted_opts = [
"samples_filename_pattern",
+ "directories_filename_pattern",
"outdir_samples",
"outdir_txt2img_samples",
"outdir_img2img_samples",
@@ -190,7 +191,8 @@ options_templates = {}
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
"samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images'),
- "samples_filename_pattern": OptionInfo("", "Images filename pattern"),
+ "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs),
+ "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
"grid_save": OptionInfo(True, "Always save all generated image grids"),
"grid_format": OptionInfo('png', 'File format for grids'),
@@ -225,8 +227,8 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
- "directories_filename_pattern": OptionInfo("", "Directory name pattern"),
- "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
+ "directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs),
+ "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
}))
options_templates.update(options_section(('upscaling', "Upscaling"), {
--
cgit v1.2.3
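
The FilenameGenerator rework above turns N full-string passes into one scan: the pattern is tokenized once, and each [token] looks up a lambda in the replacements table, so nothing expensive runs unless its token actually appears. A trimmed-down sketch built on the same two regexes (NameGen and its tokens are illustrative, not webui's):

    import re

    re_pattern = re.compile(r"([^\[\]]+|\[([^]]+)]|[\[\]]*)")
    re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")

    class NameGen:
        replacements = {
            'seed': lambda self: self.seed,
            'upper': lambda self, *args: (args[0] if args else '').upper(),
        }

        def __init__(self, seed):
            self.seed = seed

        def apply(self, x):
            res = ''
            for m in re_pattern.finditer(x):
                text, pattern = m.groups()
                if pattern is None:
                    res += text  # literal text between tokens
                    continue
                args = []
                while True:  # peel trailing <arg> groups off the token
                    am = re_pattern_arg.match(pattern)
                    if am is None:
                        break
                    pattern, arg = am.groups()
                    args.insert(0, arg)
                fun = self.replacements.get(pattern.lower())
                # unknown tokens are echoed back unchanged
                res += str(fun(self, *args)) if fun else f'[{pattern}]'
            return res

    print(NameGen(42).apply('img-[seed]-[upper<abc>]-[unknown]'))
    # img-42-ABC-[unknown]
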
From 2c05e06ea770b013de55636f15aae0b8e60e0590 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 24 Oct 2022 14:11:14 +0300
Subject: rename api/processing to api/models for #3511
---
modules/api/api.py | 2 +-
modules/api/models.py | 106 ++++++++++++++++++++++++++++++++++++++++++++++
modules/api/processing.py | 106 ----------------------------------------------
3 files changed, 107 insertions(+), 107 deletions(-)
create mode 100644 modules/api/models.py
delete mode 100644 modules/api/processing.py
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 3caa83a4..a860a964 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,4 +1,4 @@
-from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
+from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
diff --git a/modules/api/models.py b/modules/api/models.py
new file mode 100644
index 00000000..f551fa35
--- /dev/null
+++ b/modules/api/models.py
@@ -0,0 +1,106 @@
+from array import array
+from inflection import underscore
+from typing import Any, Dict, Optional
+from pydantic import BaseModel, Field, create_model
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
+import inspect
+
+
+API_NOT_ALLOWED = [
+ "self",
+ "kwargs",
+ "sd_model",
+ "outpath_samples",
+ "outpath_grids",
+ "sampler_index",
+ "do_not_save_samples",
+ "do_not_save_grid",
+ "extra_generation_params",
+ "overlay_images",
+ "do_not_reload_embeddings",
+ "seed_enable_extras",
+ "prompt_for_display",
+ "sampler_noise_scheduler_override",
+ "ddim_discretize"
+]
+
+class ModelDef(BaseModel):
+ """Assistance Class for Pydantic Dynamic Model Generation"""
+
+ field: str
+ field_alias: str
+ field_type: Any
+ field_value: Any
+
+
+class PydanticModelGenerator:
+ """
+ Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
+ source_data is a snapshot of the default values produced by the class
+ params are the names of the actual keys required by __init__
+ """
+
+ def __init__(
+ self,
+ model_name: str = None,
+ class_instance = None,
+ additional_fields = None,
+ ):
+ def field_type_generator(k, v):
+ # field_type = str if not overrides.get(k) else overrides[k]["type"]
+ # print(k, v.annotation, v.default)
+ field_type = v.annotation
+
+ return Optional[field_type]
+
+ def merge_class_params(class_):
+ all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
+ parameters = {}
+ for classes in all_classes:
+ parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
+ return parameters
+
+
+ self._model_name = model_name
+ self._class_data = merge_class_params(class_instance)
+ self._model_def = [
+ ModelDef(
+ field=underscore(k),
+ field_alias=k,
+ field_type=field_type_generator(k, v),
+ field_value=v.default
+ )
+ for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
+ ]
+
+ for fields in additional_fields:
+ self._model_def.append(ModelDef(
+ field=underscore(fields["key"]),
+ field_alias=fields["key"],
+ field_type=fields["type"],
+ field_value=fields["default"]))
+
+ def generate_model(self):
+ """
+ Creates a pydantic BaseModel
+ from the json and overrides provided at initialization
+ """
+ fields = {
+ d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
+ }
+ DynamicModel = create_model(self._model_name, **fields)
+ DynamicModel.__config__.allow_population_by_field_name = True
+ DynamicModel.__config__.allow_mutation = True
+ return DynamicModel
+
+StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingTxt2Img",
+ StableDiffusionProcessingTxt2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}]
+).generate_model()
+
+StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingImg2Img",
+ StableDiffusionProcessingImg2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
+).generate_model()
\ No newline at end of file
diff --git a/modules/api/processing.py b/modules/api/processing.py
deleted file mode 100644
index f551fa35..00000000
--- a/modules/api/processing.py
+++ /dev/null
@@ -1,106 +0,0 @@
-from array import array
-from inflection import underscore
-from typing import Any, Dict, Optional
-from pydantic import BaseModel, Field, create_model
-from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
-import inspect
-
-
-API_NOT_ALLOWED = [
- "self",
- "kwargs",
- "sd_model",
- "outpath_samples",
- "outpath_grids",
- "sampler_index",
- "do_not_save_samples",
- "do_not_save_grid",
- "extra_generation_params",
- "overlay_images",
- "do_not_reload_embeddings",
- "seed_enable_extras",
- "prompt_for_display",
- "sampler_noise_scheduler_override",
- "ddim_discretize"
-]
-
-class ModelDef(BaseModel):
- """Assistance Class for Pydantic Dynamic Model Generation"""
-
- field: str
- field_alias: str
- field_type: Any
- field_value: Any
-
-
-class PydanticModelGenerator:
- """
- Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
- source_data is a snapshot of the default values produced by the class
- params are the names of the actual keys required by __init__
- """
-
- def __init__(
- self,
- model_name: str = None,
- class_instance = None,
- additional_fields = None,
- ):
- def field_type_generator(k, v):
- # field_type = str if not overrides.get(k) else overrides[k]["type"]
- # print(k, v.annotation, v.default)
- field_type = v.annotation
-
- return Optional[field_type]
-
- def merge_class_params(class_):
- all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
- parameters = {}
- for classes in all_classes:
- parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
- return parameters
-
-
- self._model_name = model_name
- self._class_data = merge_class_params(class_instance)
- self._model_def = [
- ModelDef(
- field=underscore(k),
- field_alias=k,
- field_type=field_type_generator(k, v),
- field_value=v.default
- )
- for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
- ]
-
- for fields in additional_fields:
- self._model_def.append(ModelDef(
- field=underscore(fields["key"]),
- field_alias=fields["key"],
- field_type=fields["type"],
- field_value=fields["default"]))
-
- def generate_model(self):
- """
- Creates a pydantic BaseModel
- from the json and overrides provided at initialization
- """
- fields = {
- d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
- }
- DynamicModel = create_model(self._model_name, **fields)
- DynamicModel.__config__.allow_population_by_field_name = True
- DynamicModel.__config__.allow_mutation = True
- return DynamicModel
-
-StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
- "StableDiffusionProcessingTxt2Img",
- StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}]
-).generate_model()
-
-StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
- "StableDiffusionProcessingImg2Img",
- StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
-).generate_model()
\ No newline at end of file
--
cgit v1.2.3
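
What PydanticModelGenerator automates, in miniature (Processing is a stand-in class, not the real StableDiffusionProcessing; pydantic v1 API, as used by the module above): fields are harvested from __init__'s signature and fed to create_model, each made Optional so API callers may omit it.

    import inspect
    from typing import Optional
    from pydantic import Field, create_model

    class Processing:  # stand-in for StableDiffusionProcessingTxt2Img
        def __init__(self, prompt: str = "", steps: int = 50):
            self.prompt = prompt
            self.steps = steps

    params = inspect.signature(Processing.__init__).parameters
    fields = {
        name: (Optional[p.annotation], Field(default=p.default, alias=name))
        for name, p in params.items() if name != 'self'
    }
    ProcessingAPI = create_model('ProcessingAPI', **fields)

    req = ProcessingAPI.parse_obj({'steps': 20})
    print(req.steps, repr(req.prompt))  # 20 ''
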
From 595dca85af9e26b5d76cd64659a5bdd9da4f2b89 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Mon, 24 Oct 2022 08:32:18 -0300
Subject: Reverse run_extras change
Update serialization on the batch images endpoint
---
modules/api/api.py | 7 ++++++-
modules/api/models.py | 8 ++++----
modules/extras.py | 2 +-
3 files changed, 11 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 799e3701..67b783de 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -109,7 +109,12 @@ class Api:
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
- reqDict['image_folder'] = list(map(decode_base64_to_file, reqDict['imageList']))
+ def prepareFiles(file):
+ file = decode_base64_to_file(file.data, file_path=file.name)
+ file.orig_name = file.name
+ return file
+
+ reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
reqDict.pop('imageList')
with self.queue_lock:
diff --git a/modules/api/models.py b/modules/api/models.py
index e461d397..fca2f991 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -138,12 +138,12 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest):
class ExtrasSingleImageResponse(ExtraBaseResponse):
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
-class ImageItem(BaseModel):
- data: str = Field(title="image data")
- name: str = Field(title="filename")
+class FileData(BaseModel):
+ data: str = Field(title="File data", description="Base64 representation of the file")
+ name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
- imageList: list[str] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+ imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
diff --git a/modules/extras.py b/modules/extras.py
index 29ac312e..22c5a1c1 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -33,7 +33,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for img in image_folder:
image = Image.open(img)
imageArr.append(image)
- imageNameArr.append(os.path.splitext(img.name)[0])
+ imageNameArr.append(os.path.splitext(img.orig_name)[0])
elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
--
cgit v1.2.3
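
A hypothetical equivalent of the prepareFiles helper above (decode_base64_to_file is webui's own utility; this sketch substitutes a plain NamedTemporaryFile): the decoded bytes land in a temp file, and the submitted filename is pinned onto the object as orig_name so run_extras can still report results under meaningful names.

    import base64
    import os
    import tempfile

    def prepare_file(data_b64: str, name: str):
        # keep the original extension on the temp file's generated name
        f = tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(name)[1])
        f.write(base64.b64decode(data_b64))
        f.flush()
        f.orig_name = name  # remember what the client actually called it
        return f

    f = prepare_file(base64.b64encode(b'fake png bytes').decode(), 'cat.png')
    print(f.orig_name, '->', f.name)  # cat.png -> /tmp/tmpXXXXXXXX.png
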
From 0c0028a9d3be4fc98bbdf975a6e0daa663bbed2d Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Mon, 24 Oct 2022 21:13:36 +0900
Subject: handle UnknownTimeZoneError
---
modules/images.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index 848ede75..9a8fe3ed 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -344,7 +344,10 @@ class FilenameGenerator:
time_datetime = datetime.datetime.now()
time_format = args[0] if len(args) > 0 else self.default_time_format
- time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
+ try:
+ time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
+ except pytz.exceptions.UnknownTimeZoneError as _:
+ time_zone = None
time_zone_time = time_datetime.astimezone(time_zone)
try:
--
cgit v1.2.3
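
The failure mode being fixed, as a small sketch: pytz.timezone() raises rather than returning a default, so the lookup needs its own guard; a None zone then makes astimezone() fall back to the system's local time (zone names below are illustrative).

    import datetime
    import pytz

    def resolve_zone(name):
        try:
            return pytz.timezone(name) if name else None
        except pytz.exceptions.UnknownTimeZoneError:
            return None  # invalid zone -> local time, instead of crashing

    now = datetime.datetime.now()
    print(now.astimezone(resolve_zone('UTC')))         # converted to UTC
    print(now.astimezone(resolve_zone('Not/A_Zone')))  # local-time fallback
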
From b38370275275bf6e11575000f39c50c6e90b1f7a Mon Sep 17 00:00:00 2001
From: ritosonn
Date: Fri, 21 Oct 2022 23:46:32 +0900
Subject: fix #3145 #3093
---
modules/sd_samplers.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 0b408a70..3670b57d 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -228,7 +228,7 @@ class VanillaStableDiffusionSampler:
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
- samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
+ samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
@@ -429,7 +429,7 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x
self.last_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
--
cgit v1.2.3
From 77a320f406a76425176b8ca4c034c362b6734713 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 24 Oct 2022 17:23:51 +0300
Subject: do not stop execution when script's callback misbehaves and report
which script it was
---
modules/script_callbacks.py | 45 ++++++++++++++++++++++++++++++++++++---------
1 file changed, 36 insertions(+), 9 deletions(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index f46d3d9a..5e9a6d4b 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,4 +1,15 @@
+import sys
+import traceback
+from collections import namedtuple
+import inspect
+
+def report_exception(c, job):
+ print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+
+ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
@@ -10,28 +21,44 @@ def clear_callbacks():
def model_loaded_callback(sd_model):
- for callback in callbacks_model_loaded:
- callback(sd_model)
+ for c in callbacks_model_loaded:
+ try:
+ c.callback(sd_model)
+ except Exception:
+ report_exception(c, 'model_loaded_callback')
def ui_tabs_callback():
res = []
- for callback in callbacks_ui_tabs:
- res += callback() or []
+ for c in callbacks_ui_tabs:
+ try:
+ res += c.callback() or []
+ except Exception:
+ report_exception(c, 'ui_tabs_callback')
return res
def ui_settings_callback():
- for callback in callbacks_ui_settings:
- callback()
+ for c in callbacks_ui_settings:
+ try:
+ c.callback()
+ except Exception:
+ report_exception(c, 'ui_settings_callback')
+
+
+def add_callback(callbacks, fun):
+ stack = [x for x in inspect.stack() if x.filename != __file__]
+ filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+
+ callbacks.append(ScriptCallback(filename, fun))
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
- callbacks_model_loaded.append(callback)
+ add_callback(callbacks_model_loaded, callback)
def on_ui_tabs(callback):
@@ -44,10 +71,10 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
- callbacks_ui_tabs.append(callback)
+ add_callback(callbacks_ui_tabs, callback)
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
- callbacks_ui_settings.append(callback)
+ add_callback(callbacks_ui_settings, callback)
--
cgit v1.2.3
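
The two techniques of the commit above, in a self-contained sketch: add_callback records the registering file by filtering its own frames out of inspect.stack(), and every invocation is wrapped so one misbehaving callback cannot abort the rest. (Run as a single script, every frame is this file, so the 'unknown file' fallback fires; in webui the extension's filename is captured.)

    import sys
    import traceback
    import inspect
    from collections import namedtuple

    ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
    callbacks_demo = []

    def add_callback(callbacks, fun):
        # first stack frame that is not this module = the registering script
        stack = [x for x in inspect.stack() if x.filename != __file__]
        filename = stack[0].filename if len(stack) > 0 else 'unknown file'
        callbacks.append(ScriptCallback(filename, fun))

    def run_callbacks(arg):
        for c in callbacks_demo:
            try:
                c.callback(arg)
            except Exception:
                print(f"Error executing callback for {c.script}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    add_callback(callbacks_demo, lambda arg: print("ok:", arg))
    add_callback(callbacks_demo, lambda arg: 1 / 0)  # misbehaves
    run_callbacks("model")  # "ok: model", then a report; execution continues
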
From 4c24347e45776d505937856ab280548d9298f0a8 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Mon, 24 Oct 2022 23:04:50 -0400
Subject: Remove BSRGAN from --use-cpu, add SwinIR
---
modules/devices.py | 2 +-
modules/shared.py | 6 +++---
modules/swinir_model.py | 12 ++++++------
3 files changed, 10 insertions(+), 10 deletions(-)
(limited to 'modules')
diff --git a/modules/devices.py b/modules/devices.py
index dc1f3cdd..033a42d5 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -45,7 +45,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
+device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
diff --git a/modules/shared.py b/modules/shared.py
index 76cbb1bd..308fccce 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -58,7 +58,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -96,8 +96,8 @@ restricted_opts = [
"outdir_save",
]
-devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
+(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index baa02e3d..facd262d 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -7,8 +7,8 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
-from modules import modelloader
-from modules.shared import cmd_opts, opts, device
+from modules import modelloader, devices
+from modules.shared import cmd_opts, opts
from modules.swinir_model_arch import SwinIR as net
from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@@ -42,7 +42,7 @@ class UpscalerSwinIR(Upscaler):
model = self.load_model(model_file)
if model is None:
return img
- model = model.to(device)
+ model = model.to(devices.device_swinir)
img = upscale(img, model)
try:
torch.cuda.empty_cache()
@@ -111,7 +111,7 @@ def upscale(
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
+ img = img.unsqueeze(0).to(devices.device_swinir)
with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
@@ -139,8 +139,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
- W = torch.zeros_like(E, dtype=torch.half, device=device)
+ E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img)
+ W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
--
cgit v1.2.3
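
The device table in shared.py is built by unpacking one entry per module name from a generator expression; a minimal sketch of that pattern (the module names and use_cpu list here are illustrative):

    import torch

    use_cpu = ['swinir']  # e.g. parsed from --use-cpu swinir

    def get_optimal_device():
        return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # one device per module; CPU when the module (or 'all') was requested
    device_sd, device_swinir, device_esrgan = (
        torch.device('cpu') if any(y in use_cpu for y in [x, 'all']) else get_optimal_device()
        for x in ['sd', 'swinir', 'esrgan']
    )
    print(device_sd, device_swinir, device_esrgan)
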
From faed465a0b1a7d19669568738c93e04907c10415 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Tue, 25 Oct 2022 02:01:57 -0400
Subject: MPS Upscalers Fix
Get ESRGAN, SCUNet, and SwinIR working correctly on MPS by ensuring memory is contiguous for tensor views before sending to MPS device.
---
modules/devices.py | 4 ++++
modules/esrgan_model.py | 2 +-
modules/scunet_model.py | 3 +--
modules/swinir_model.py | 2 +-
4 files changed, 7 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/devices.py b/modules/devices.py
index 033a42d5..7511e1dc 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -81,3 +81,7 @@ def autocast(disable=False):
return contextlib.nullcontext()
return torch.autocast("cuda")
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
+def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index a49e2258..a13cf6ac 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -190,7 +190,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_esrgan)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/scunet_model.py b/modules/scunet_model.py
index 36a996bf..59532274 100644
--- a/modules/scunet_model.py
+++ b/modules/scunet_model.py
@@ -54,9 +54,8 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), device)
- img = img.to(device)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index facd262d..4253b66d 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -111,7 +111,7 @@ def upscale(
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_swinir)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir)
with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
--
cgit v1.2.3
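
Sketch of why the helpers above exist: view ops such as permute() return non-contiguous tensors, which trip pytorch/pytorch#79383 when moved to MPS, so the tensor is copied into contiguous memory first. The helper mirrors devices.mps_contiguous_to and falls back to CPU where MPS is unavailable.

    import torch

    def mps_contiguous_to(t: torch.Tensor, device: torch.device) -> torch.Tensor:
        # contiguous() copies only when needed, and only for MPS targets
        return (t.contiguous() if device.type == 'mps' else t).to(device)

    device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
    img = torch.rand(8, 8, 3).permute(2, 0, 1).unsqueeze(0)  # non-contiguous view
    img = mps_contiguous_to(img, device)
    print(img.device, img.is_contiguous())
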
From 91c1e1e6a92061b99c92a5b1d548535907d2ad96 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Tue, 25 Oct 2022 00:46:28 +0900
Subject: fix default filename pattern
---
modules/images.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index 0e044af2..286de2ae 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -477,8 +477,10 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if forced_filename is None:
if short_filename or seed is None:
file_decoration = ""
- else:
+ elif opts.save_to_dirs:
file_decoration = opts.samples_filename_pattern or "[seed]"
+ else:
+ file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
add_number = opts.save_images_add_number or file_decoration == ''
--
cgit v1.2.3
From ff305acd51cc71c5eea8aee0f537a26a6d1ba2a1 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Tue, 25 Oct 2022 15:33:43 +0800
Subject: some rights for extensions
---
modules/shared.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 76cbb1bd..7b1fadf2 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -82,6 +82,7 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
+parser.add_argument("--administrator", type=str, help="Administrator rights", default=None)
cmd_opts = parser.parse_args()
restricted_opts = [
--
cgit v1.2.3
From 3e15f8e0f5cc87507f77546d92435670644dbd18 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 25 Oct 2022 12:16:17 +0300
Subject: update callbacks code for #3549
---
modules/script_callbacks.py | 22 ++++++++++++++++------
1 file changed, 16 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 9933fa38..dc520abc 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -49,6 +49,14 @@ def ui_settings_callback():
report_exception(c, 'ui_settings_callback')
+def image_saved_callback(image, p, fullfn, txt_fullfn):
+ for c in callbacks_image_saved:
+ try:
+ c.callback(image, p, fullfn, txt_fullfn)
+ except Exception:
+ report_exception(c, 'image_saved_callback')
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -56,9 +64,6 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
-def image_saved_callback(image, p, fullfn, txt_fullfn):
- for callback in callbacks_image_saved:
- callback(image, p, fullfn, txt_fullfn)
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
@@ -82,9 +87,14 @@ def on_ui_tabs(callback):
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
- callbacks_ui_settings.append(callback)
+ add_callback(callbacks_ui_settings, callback)
def on_save_imaged(callback):
- """register a function to call after modules.images.save_image is called returning same values, original image and p """
- callbacks_image_saved.append(callback)
+ """register a function to be called after modules.images.save_image is called.
+ The callback is called with four arguments:
+ - image - the saved PIL image
+ - p - processing object (or a dummy object with same fields if the image is saved using save button)
+ - fullfn - image filename
+ - txt_fullfn - text file with parameters; may be None
+ """
+ add_callback(callbacks_image_saved, callback)
--
cgit v1.2.3
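
Hypothetical extension-side usage of the hook documented above (the registration function is spelled on_save_imaged in this revision; per the call site in image_saved_callback, the callback receives the image as its first argument):

    from modules import script_callbacks

    def log_saved_image(image, p, fullfn, txt_fullfn):
        # p may be a dummy object when saving via the UI's Save button
        print(f"saved {fullfn} ({image.width}x{image.height}), params: {txt_fullfn}")

    script_callbacks.on_save_imaged(log_saved_image)
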
From 9ba439b53313ef78984dd8e39f25b34501188ee2 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Tue, 25 Oct 2022 18:48:07 +0800
Subject: need some rights for extensions
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 7b1fadf2..b5975707 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -82,7 +82,7 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
-parser.add_argument("--administrator", type=str, help="Administrator rights", default=None)
+parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
cmd_opts = parser.parse_args()
restricted_opts = [
--
cgit v1.2.3
From f9549d1cbb3f1d7d1f0fb70375a06e31f9c5dd9d Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Tue, 25 Oct 2022 11:14:12 -0700
Subject: Added option to use unmasked conditioning image.
---
modules/processing.py | 6 +++++-
modules/shared.py | 1 +
2 files changed, 6 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index c61bbfbd..96f56b0d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -768,7 +768,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
# Create another latent image, this time with a masked version of the original input.
conditioning_mask = conditioning_mask.to(image.device)
- conditioning_image = image * (1.0 - conditioning_mask)
+
+ conditioning_image = image
+ if shared.opts.inpainting_mask_image:
+ conditioning_image = conditioning_image * (1.0 - conditioning_mask)
+
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
diff --git a/modules/shared.py b/modules/shared.py
index 308fccce..1d0ff1a1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -320,6 +320,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
+ "inpainting_mask_image": OptionInfo(True, "Mask original image for conditioning used by inpainting model."),
}))
--
cgit v1.2.3
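
What the new option toggles, on toy tensors (shapes are illustrative): with masking enabled, the conditioning image is zeroed wherever the mask is 1, so the inpainting model only sees the unmasked context.

    import torch

    inpainting_mask_image = True  # mirrors shared.opts.inpainting_mask_image
    image = torch.rand(1, 3, 4, 4) * 2 - 1              # toy image in [-1, 1]
    conditioning_mask = (torch.rand(1, 1, 4, 4) > 0.5).float()

    conditioning_image = image
    if inpainting_mask_image:
        # masked pixels (mask == 1) are zeroed; the rest pass through
        conditioning_image = conditioning_image * (1.0 - conditioning_mask)
    print(conditioning_image.shape)  # torch.Size([1, 3, 4, 4])
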
From 605d27687f433c0fefb9025aace7dc94f0ebd454 Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Tue, 25 Oct 2022 12:20:54 -0700
Subject: Added conditioning image masking to xy_grid. Use `True` and `False`
to select values.
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 96f56b0d..23ee5e02 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -770,7 +770,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
conditioning_mask = conditioning_mask.to(image.device)
conditioning_image = image
- if shared.opts.inpainting_mask_image:
+ if getattr(self, "inpainting_mask_image", shared.opts.inpainting_mask_image):
conditioning_image = conditioning_image * (1.0 - conditioning_mask)
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
--
cgit v1.2.3
From 3e6c2420c1177e9e79f2b566a5a7795b7416e34a Mon Sep 17 00:00:00 2001
From: captin411
Date: Tue, 25 Oct 2022 13:10:58 -0700
Subject: improve debug markers, fix algo weighting
---
modules/textual_inversion/autocrop.py | 207 +++++++++++++++++++++-------------
1 file changed, 129 insertions(+), 78 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index b2f9241c..caaf18c8 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -1,4 +1,5 @@
import cv2
+import os
from collections import defaultdict
from math import log, sqrt
import numpy as np
@@ -26,19 +27,9 @@ def crop_image(im, settings):
scale_by = settings.crop_height / im.height
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+ im_debug = im.copy()
- if im.width == settings.crop_width and im.height == settings.crop_height:
- if settings.annotate_image:
- d = ImageDraw.Draw(im)
- rect = [0, 0, im.width, im.height]
- rect[2] -= 1
- rect[3] -= 1
- d.rectangle(rect, outline=GREEN)
- if settings.destop_view_image:
- im.show()
- return im
-
- focus = focal_point(im, settings)
+ focus = focal_point(im_debug, settings)
# take the focal point and turn it into crop coordinates that try to center over the focal
# point but then get adjusted back into the frame
@@ -62,89 +53,143 @@ def crop_image(im, settings):
crop = [x1, y1, x2, y2]
+ results = []
+
+ results.append(im.crop(tuple(crop)))
+
if settings.annotate_image:
- d = ImageDraw.Draw(im)
+ d = ImageDraw.Draw(im_debug)
rect = list(crop)
rect[2] -= 1
rect[3] -= 1
d.rectangle(rect, outline=GREEN)
+ results.append(im_debug)
if settings.destop_view_image:
- im.show()
+ im_debug.show()
- return im.crop(tuple(crop))
+ return results
def focal_point(im, settings):
corner_points = image_corner_points(im, settings)
entropy_points = image_entropy_points(im, settings)
face_points = image_face_points(im, settings)
- total_points = len(corner_points) + len(entropy_points) + len(face_points)
-
- corner_weight = settings.corner_points_weight
- entropy_weight = settings.entropy_points_weight
- face_weight = settings.face_points_weight
-
- weight_pref_total = corner_weight + entropy_weight + face_weight
-
- # weight things
pois = []
- if weight_pref_total == 0 or total_points == 0:
- return pois
- pois.extend(
- [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ]
- )
- pois.extend(
- [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ]
- )
- pois.extend(
- [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ]
- )
+ weight_pref_total = 0
+ if len(corner_points) > 0:
+ weight_pref_total += settings.corner_points_weight
+ if len(entropy_points) > 0:
+ weight_pref_total += settings.entropy_points_weight
+ if len(face_points) > 0:
+ weight_pref_total += settings.face_points_weight
+
+ corner_centroid = None
+ if len(corner_points) > 0:
+ corner_centroid = centroid(corner_points)
+ corner_centroid.weight = settings.corner_points_weight / weight_pref_total
+ pois.append(corner_centroid)
+
+ entropy_centroid = None
+ if len(entropy_points) > 0:
+ entropy_centroid = centroid(entropy_points)
+ entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
+ pois.append(entropy_centroid)
+
+ face_centroid = None
+ if len(face_points) > 0:
+ face_centroid = centroid(face_points)
+ face_centroid.weight = settings.face_points_weight / weight_pref_total
+ pois.append(face_centroid)
average_point = poi_average(pois, settings)
if settings.annotate_image:
d = ImageDraw.Draw(im)
- for f in face_points:
- d.rectangle(f.bounding(f.size), outline=RED)
- for f in entropy_points:
- d.rectangle(f.bounding(30), outline=BLUE)
- for poi in pois:
- w = max(4, 4 * 0.5 * sqrt(poi.weight))
- d.ellipse(poi.bounding(w), fill=BLUE)
- d.ellipse(average_point.bounding(25), outline=GREEN)
+ max_size = min(im.width, im.height) * 0.07
+ if corner_centroid is not None:
+ color = BLUE
+ box = corner_centroid.bounding(max_size * corner_centroid.weight)
+ d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(corner_points) > 1:
+ for f in corner_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if entropy_centroid is not None:
+ color = "#ff0"
+ box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
+ d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(entropy_points) > 1:
+ for f in entropy_points:
+ d.rectangle(f.bounding(4), outline=color)
+ if face_centroid is not None:
+ color = RED
+ box = face_centroid.bounding(max_size * face_centroid.weight)
+ d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
+ d.ellipse(box, outline=color)
+ if len(face_points) > 1:
+ for f in face_points:
+ d.rectangle(f.bounding(4), outline=color)
+
+ d.ellipse(average_point.bounding(max_size), outline=GREEN)
return average_point
def image_face_points(im, settings):
- np_im = np.array(im)
- gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
-
- tries = [
- [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
- [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
- ]
-
- for t in tries:
- # print(t[0])
- classifier = cv2.CascadeClassifier(t[0])
- minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
- try:
- faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
- minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
- except:
- continue
-
- if len(faces) > 0:
- rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
- return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects]
+ if settings.dnn_model_path is not None:
+ detector = cv2.FaceDetectorYN.create(
+ settings.dnn_model_path,
+ "",
+ (im.width, im.height),
+ 0.8, # score threshold
+ 0.3, # nms threshold
+ 5000 # keep top k before nms
+ )
+ faces = detector.detect(np.array(im))
+ results = []
+ if faces[1] is not None:
+ for face in faces[1]:
+ x = face[0]
+ y = face[1]
+ w = face[2]
+ h = face[3]
+ results.append(
+ PointOfInterest(
+ int(x + (w * 0.5)), # face focus left/right is center
+ int(y + (h * 0)), # face focus up/down is close to the top of the head
+ size = w,
+ weight = 1/len(faces[1])
+ )
+ )
+ return results
+ else:
+ np_im = np.array(im)
+ gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
+
+ tries = [
+ [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
+ [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
+ ]
+ for t in tries:
+ classifier = cv2.CascadeClassifier(t[0])
+ minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
+ try:
+ faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
+ minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
+ except:
+ continue
+
+ if len(faces) > 0:
+ rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
+ return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
return []
@@ -161,7 +206,7 @@ def image_corner_points(im, settings):
np_im,
maxCorners=100,
qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.07,
+ minDistance=min(grayscale.width, grayscale.height)*0.03,
useHarrisDetector=False,
)
@@ -171,7 +216,7 @@ def image_corner_points(im, settings):
focal_points = []
for point in points:
x, y = point.ravel()
- focal_points.append(PointOfInterest(x, y, size=4))
+ focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
return focal_points
@@ -205,17 +250,22 @@ def image_entropy_points(im, settings):
x_mid = int(crop_best[0] + settings.crop_width/2)
y_mid = int(crop_best[1] + settings.crop_height/2)
- return [PointOfInterest(x_mid, y_mid, size=25)]
+ return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
def image_entropy(im):
# greyscale image entropy
- # band = np.asarray(im.convert("L"))
- band = np.asarray(im.convert("1"), dtype=np.uint8)
+ band = np.asarray(im.convert("L"))
+ # band = np.asarray(im.convert("1"), dtype=np.uint8)
hist, _ = np.histogram(band, bins=range(0, 256))
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
+def centroid(pois):
+ x = [poi.x for poi in pois]
+ y = [poi.y for poi in pois]
+ return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+
def poi_average(pois, settings):
weight = 0.0
@@ -260,11 +310,12 @@ class PointOfInterest:
class Settings:
- def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False):
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
self.crop_width = crop_width
self.crop_height = crop_height
self.corner_points_weight = corner_points_weight
self.entropy_points_weight = entropy_points_weight
- self.face_points_weight = entropy_points_weight
+ self.face_points_weight = face_points_weight
self.annotate_image = annotate_image
- self.destop_view_image = False
\ No newline at end of file
+ self.destop_view_image = False
+ self.dnn_model_path = dnn_model_path
\ No newline at end of file
--
cgit v1.2.3
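The patch above replaces per-point weighting with a centroid-of-centroids scheme: each detector's points collapse to one centroid, weighted by that detector's share of the preference total over non-empty groups only, and the centroids are then averaged. A minimal standalone sketch of that scheme, with a simplified stand-in for the module's PointOfInterest class:

```python
# Minimal sketch of the centroid-of-centroids weighting, with a
# simplified stand-in for the module's PointOfInterest class.
class PointOfInterest:
    def __init__(self, x, y, weight=1.0):
        self.x, self.y, self.weight = x, y, weight

def centroid(pois):
    return PointOfInterest(sum(p.x for p in pois) / len(pois),
                           sum(p.y for p in pois) / len(pois))

def weighted_focal_point(groups, prefs):
    # groups: one point list per detector (faces, corners, entropy);
    # prefs: the preference weight per detector. Only non-empty groups
    # contribute to the normalizing total, as in the patch.
    total = sum(w for g, w in zip(groups, prefs) if g)
    pois = []
    for g, w in zip(groups, prefs):
        if g:
            c = centroid(g)
            c.weight = w / total
            pois.append(c)
    wsum = sum(p.weight for p in pois)  # weighted average of centroids
    return PointOfInterest(sum(p.x * p.weight for p in pois) / wsum,
                           sum(p.y * p.weight for p in pois) / wsum)

faces = [PointOfInterest(100, 80), PointOfInterest(120, 90)]
corners = [PointOfInterest(10, 10)]
fp = weighted_focal_point([faces, corners, []], [0.9, 0.5, 0.3])
print(int(fp.x), int(fp.y))
```

The practical gain is that an empty detector no longer dilutes the normalization, so a single confident face detection cannot be washed out by dozens of corner points.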
From 8b4f32779f28010fc8077e8fcfb85a3205b36bc2 Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Tue, 25 Oct 2022 13:15:08 -0700
Subject: Switch to a continuous blend for cond. image.
---
modules/processing.py | 9 ++++++---
modules/shared.py | 2 +-
2 files changed, 7 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 23ee5e02..02292bdc 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -769,9 +769,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
# Create another latent image, this time with a masked version of the original input.
conditioning_mask = conditioning_mask.to(image.device)
- conditioning_image = image
- if getattr(self, "inpainting_mask_image", shared.opts.inpainting_mask_image):
- conditioning_image = conditioning_image * (1.0 - conditioning_mask)
+ # Smoothly interpolate between the masked and unmasked latent conditioning image.
+ conditioning_image = torch.lerp(
+ image,
+ image * (1.0 - conditioning_mask),
+ getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
+ )
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
diff --git a/modules/shared.py b/modules/shared.py
index 1d0ff1a1..e0ffb824 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -320,7 +320,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
- "inpainting_mask_image": OptionInfo(True, "Mask original image for conditioning used by inpainting model."),
+ "inpainting_mask_weight": OptionInfo(1.0, "Blend betweeen an unmasked and masked conditioning image for inpainting models.", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
}))
--
cgit v1.2.3
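The key change is that the boolean "mask original image" option becomes a weight: torch.lerp interpolates between the unmasked and fully masked conditioning image, so 1.0 reproduces the old masked behaviour and 0.0 disables the masking entirely. A small self-contained sketch with dummy tensors:

```python
import torch

# Sketch of the continuous conditioning blend: weight 1.0 reproduces the
# old fully-masked behaviour, 0.0 leaves the conditioning image untouched.
image = torch.rand(1, 3, 8, 8)
conditioning_mask = (torch.rand(1, 1, 8, 8) > 0.5).float()

for weight in (0.0, 0.5, 1.0):
    conditioning_image = torch.lerp(image, image * (1.0 - conditioning_mask), weight)
    print(weight, conditioning_image.abs().mean().item())
```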
From db8ed5fe5cd6e967d12d43d96b7f83083e58626c Mon Sep 17 00:00:00 2001
From: captin411
Date: Tue, 25 Oct 2022 15:22:29 -0700
Subject: Focal crop UI elements
---
modules/textual_inversion/preprocess.py | 26 +++++++++++++-------------
modules/ui.py | 20 ++++++++++++++++++--
2 files changed, 31 insertions(+), 15 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index a8c17c6f..1e4d4de8 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -13,7 +13,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_entropy_focus=False):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
try:
if process_caption:
shared.interrogator.load()
@@ -23,7 +23,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
- preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_entropy_focus)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
finally:
@@ -35,7 +35,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_entropy_focus=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -139,27 +139,27 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
ratio = (img.height * width) / (img.width * height)
inverse_xy = True
- processing_option_ran = False
+ process_default_resize = True
if process_split and ratio < 1.0 and ratio <= split_threshold:
for splitted in split_pic(img, inverse_xy):
save_pic(splitted, index, existing_caption=existing_caption)
- processing_option_ran = True
+ process_default_resize = False
if process_entropy_focus and img.height != img.width:
autocrop_settings = autocrop.Settings(
crop_width = width,
crop_height = height,
- face_points_weight = 0.9,
- entropy_points_weight = 0.7,
- corner_points_weight = 0.5,
- annotate_image = False
+ face_points_weight = process_focal_crop_face_weight,
+ entropy_points_weight = process_focal_crop_entropy_weight,
+ corner_points_weight = process_focal_crop_edges_weight,
+ annotate_image = process_focal_crop_debug
)
- focal = autocrop.crop_image(img, autocrop_settings)
- save_pic(focal, index, existing_caption=existing_caption)
- processing_option_ran = True
+ for focal in autocrop.crop_image(img, autocrop_settings):
+ save_pic(focal, index, existing_caption=existing_caption)
+ process_default_resize = False
- if not processing_option_ran:
+ if process_default_resize:
img = images.resize_image(1, img, width, height)
save_pic(img, index, existing_caption=existing_caption)
diff --git a/modules/ui.py b/modules/ui.py
index 028eb4e5..95b9c703 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1260,7 +1260,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
process_split = gr.Checkbox(label='Split oversized images')
- process_entropy_focus = gr.Checkbox(label='Create auto focal point crop')
+ process_focal_crop = gr.Checkbox(label='Auto focal point crop')
process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
@@ -1268,6 +1268,12 @@ def create_ui(wrap_gradio_gpu_call):
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
+ with gr.Row(visible=False) as process_focal_crop_row:
+ process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05)
+ process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.3, minimum=0.0, maximum=1.0, step=0.05)
+ process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
+ process_focal_crop_debug = gr.Checkbox(label='Create debug image')
+
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1281,6 +1287,12 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[process_split_extra_row],
)
+ process_focal_crop.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_focal_crop],
+ outputs=[process_focal_crop_row],
+ )
+
with gr.Tab(label="Train"):
gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
")
with gr.Row():
@@ -1368,7 +1380,11 @@ def create_ui(wrap_gradio_gpu_call):
process_caption_deepbooru,
process_split_threshold,
process_overlap_ratio,
- process_entropy_focus,
+ process_focal_crop,
+ process_focal_crop_face_weight,
+ process_focal_crop_entropy_weight,
+ process_focal_crop_edges_weight,
+ process_focal_crop_debug,
],
outputs=[
ti_output,
--
cgit v1.2.3
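The new sliders live in a gr.Row that starts hidden and is toggled by the checkbox's change event. A hypothetical standalone demo of that show/hide pattern (gr_show here mimics the webui helper, which returns a visibility update):

```python
import gradio as gr

# Hypothetical standalone demo of the show/hide pattern above; gr_show
# mimics the webui helper by returning a visibility update.
def gr_show(visible):
    return gr.update(visible=visible)

with gr.Blocks() as demo:
    enable = gr.Checkbox(label='Auto focal point crop')
    with gr.Row(visible=False) as focal_row:
        gr.Slider(label='Focal point face weight', value=0.9,
                  minimum=0.0, maximum=1.0, step=0.05)
    enable.change(fn=gr_show, inputs=[enable], outputs=[focal_row])

# demo.launch()  # uncomment to try it locally
```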
From 54f0c1482427a5b3f2248b97be55878e742cbcb1 Mon Sep 17 00:00:00 2001
From: captin411
Date: Tue, 25 Oct 2022 16:14:13 -0700
Subject: download better face detection model dynamically
---
modules/textual_inversion/autocrop.py | 20 ++++++++++++++++++++
modules/textual_inversion/preprocess.py | 13 +++++++++++--
2 files changed, 31 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index caaf18c8..01a92b12 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -1,4 +1,5 @@
import cv2
+import requests
import os
from collections import defaultdict
from math import log, sqrt
@@ -293,6 +294,25 @@ def is_square(w, h):
return w == h
+def download_and_cache_models(dirname):
+ download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
+ model_file_name = 'face_detection_yunet.onnx'
+
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ cache_file = os.path.join(dirname, model_file_name)
+ if not os.path.exists(cache_file):
+ print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
+ response = requests.get(download_url)
+ with open(cache_file, "wb") as f:
+ f.write(response.content)
+
+ if os.path.exists(cache_file):
+ return cache_file
+ return None
+
+
class PointOfInterest:
def __init__(self, x, y, weight=1.0, size=10):
self.x = x
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 1e4d4de8..e13b1894 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -7,6 +7,7 @@ import tqdm
import time
from modules import shared, images
+from modules.paths import models_path
from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
if cmd_opts.deepdanbooru:
@@ -146,14 +147,22 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
save_pic(splitted, index, existing_caption=existing_caption)
process_default_resize = False
- if process_entropy_focus and img.height != img.width:
+ if process_focal_crop and img.height != img.width:
+
+ dnn_model_path = None
+ try:
+ dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv"))
+ except Exception as e:
+ print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
+
autocrop_settings = autocrop.Settings(
crop_width = width,
crop_height = height,
face_points_weight = process_focal_crop_face_weight,
entropy_points_weight = process_focal_crop_entropy_weight,
corner_points_weight = process_focal_crop_edges_weight,
- annotate_image = process_focal_crop_debug
+ annotate_image = process_focal_crop_debug,
+ dnn_model_path = dnn_model_path,
)
for focal in autocrop.crop_image(img, autocrop_settings):
save_pic(focal, index, existing_caption=existing_caption)
--
cgit v1.2.3
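download_and_cache_models buffers the whole response in memory before writing; for larger files the same cache-or-download pattern can stream to disk instead. A sketch under that assumption (url and filename are placeholders):

```python
import os
import requests

# Streaming variant of the cache helper; url and filename are placeholders.
def download_and_cache(url, dirname, filename):
    os.makedirs(dirname, exist_ok=True)
    cache_file = os.path.join(dirname, filename)
    if not os.path.exists(cache_file):
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(cache_file, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 16):
                    f.write(chunk)  # write as it arrives instead of buffering
    return cache_file
```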
From df0c5ea29d7f0c682ac81f184f3e482a6450d018 Mon Sep 17 00:00:00 2001
From: captin411
Date: Tue, 25 Oct 2022 17:06:59 -0700
Subject: update default weights
---
modules/textual_inversion/autocrop.py | 16 ++++++++--------
modules/ui.py | 2 +-
2 files changed, 9 insertions(+), 9 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index 01a92b12..9859974a 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -71,9 +71,9 @@ def crop_image(im, settings):
return results
def focal_point(im, settings):
- corner_points = image_corner_points(im, settings)
- entropy_points = image_entropy_points(im, settings)
- face_points = image_face_points(im, settings)
+ corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
+ entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
+ face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
pois = []
@@ -144,7 +144,7 @@ def image_face_points(im, settings):
settings.dnn_model_path,
"",
(im.width, im.height),
- 0.8, # score threshold
+ 0.9, # score threshold
0.3, # nms threshold
5000 # keep top k before nms
)
@@ -159,7 +159,7 @@ def image_face_points(im, settings):
results.append(
PointOfInterest(
int(x + (w * 0.5)), # face focus left/right is center
- int(y + (h * 0)), # face focus up/down is close to the top of the head
+ int(y + (h * 0.33)), # face focus up/down is close to the top of the head
size = w,
weight = 1/len(faces[1])
)
@@ -207,7 +207,7 @@ def image_corner_points(im, settings):
np_im,
maxCorners=100,
qualityLevel=0.04,
- minDistance=min(grayscale.width, grayscale.height)*0.03,
+ minDistance=min(grayscale.width, grayscale.height)*0.06,
useHarrisDetector=False,
)
@@ -256,8 +256,8 @@ def image_entropy_points(im, settings):
def image_entropy(im):
# greyscale image entropy
- band = np.asarray(im.convert("L"))
- # band = np.asarray(im.convert("1"), dtype=np.uint8)
+ # band = np.asarray(im.convert("L"))
+ band = np.asarray(im.convert("1"), dtype=np.uint8)
hist, _ = np.histogram(band, bins=range(0, 256))
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
diff --git a/modules/ui.py b/modules/ui.py
index 95b9c703..095200a8 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1270,7 +1270,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row(visible=False) as process_focal_crop_row:
process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05)
- process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.3, minimum=0.0, maximum=1.0, step=0.05)
+ process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05)
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
process_focal_crop_debug = gr.Checkbox(label='Create debug image')
--
cgit v1.2.3
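This commit also flips image_entropy back to a 1-bit band: dithered black-and-white tends to separate busy regions from flat ones more sharply than 8-bit greyscale does. A sketch comparing the two modes on a synthetic gradient (radial_gradient is just demo input):

```python
import numpy as np
from PIL import Image

# Entropy of an image band; mode "1" (dithered 1-bit) vs mode "L" (8-bit
# greyscale) changes which candidate crops score as "busy".
def image_entropy(im, mode="1"):
    band = np.asarray(im.convert(mode), dtype=np.uint8)
    hist, _ = np.histogram(band, bins=range(0, 256))
    hist = hist[hist > 0]
    return -np.log2(hist / hist.sum()).sum()

im = Image.radial_gradient("L").resize((64, 64))
print(image_entropy(im, "L"), image_entropy(im, "1"))
```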
From 2f4c91894d4c0a055c1069b2fda0e4da8fcda188 Mon Sep 17 00:00:00 2001
From: guaneec
Date: Wed, 26 Oct 2022 12:10:30 +0800
Subject: Remove activation from final layer of HNs
---
modules/hypernetworks/hypernetwork.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index d647ea55..54346b64 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -41,8 +41,8 @@ class HypernetworkModule(torch.nn.Module):
# Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
- # Add an activation func
- if activation_func == "linear" or activation_func is None:
+ # Add an activation func except last layer
+ if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 3:
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
@@ -53,7 +53,7 @@ class HypernetworkModule(torch.nn.Module):
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
- # Add dropout expect last layer
+ # Add dropout except last layer
if use_dropout and i < len(layer_structure) - 3:
linears.append(torch.nn.Dropout(p=0.3))
--
cgit v1.2.3
From c702d4d0df21790199d199818f25c449213ffe0f Mon Sep 17 00:00:00 2001
From: guaneec
Date: Wed, 26 Oct 2022 13:43:04 +0800
Subject: Fix off-by-one
---
modules/hypernetworks/hypernetwork.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 54346b64..3ce85bb5 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -42,7 +42,7 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# Add an activation func except last layer
- if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 3:
+ if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 2:
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
@@ -54,7 +54,7 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout except last layer
- if use_dropout and i < len(layer_structure) - 3:
+ if use_dropout and i < len(layer_structure) - 2:
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
--
cgit v1.2.3
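With layer_structure [1, 2, 1], the loop index i runs over 0 and 1 and the final Linear is built at i == len(layer_structure) - 2, which is why the corrected guard uses - 2 rather than - 3. A simplified sketch of the construction (dropout placement varies across the surrounding commits; here it follows the activation rule):

```python
import torch

# With layer_structure [1, 2, 1] and dim 320, i runs over 0 and 1;
# i == len(layer_structure) - 2 builds the final Linear, so the guard
# below skips its activation (and dropout) exactly once.
def build(dim, layer_structure, activation=torch.nn.ReLU, use_dropout=True):
    linears = []
    for i in range(len(layer_structure) - 1):
        linears.append(torch.nn.Linear(int(dim * layer_structure[i]),
                                       int(dim * layer_structure[i + 1])))
        if i < len(layer_structure) - 2:  # everything except the last layer
            linears.append(activation())
            if use_dropout:
                linears.append(torch.nn.Dropout(p=0.3))
    return torch.nn.Sequential(*linears)

print(build(320, [1, 2, 1]))
```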
From de096d0ce752c96e45508dcc7b9e84f7dbe10cca Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Tue, 25 Oct 2022 14:48:49 +0900
Subject: Weight initialization and More activation func
add weight init
add weight init option in create_hypernetwork
fstringify hypernet info
save weight initialization info for further debugging
fill bias with zero for He/Xavier
initialize LayerNorm with Normal
fix loading weight_init
---
modules/hypernetworks/hypernetwork.py | 47 ++++++++++++++++++++++++++++-------
modules/hypernetworks/ui.py | 4 ++-
modules/ui.py | 4 ++-
3 files changed, 44 insertions(+), 11 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index d647ea55..afbcdff8 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -5,6 +5,7 @@ import html
import os
import sys
import traceback
+import inspect
import modules.textual_inversion.dataset
import torch
@@ -15,10 +16,12 @@ from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
+from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
from collections import defaultdict, deque
from statistics import stdev, mean
+
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -26,9 +29,12 @@ class HypernetworkModule(torch.nn.Module):
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
+ "tanh": torch.nn.Tanh,
+ "sigmoid": torch.nn.Sigmoid,
}
+ activation_dict.update({cls_name: cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
- def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@@ -65,9 +71,24 @@ class HypernetworkModule(torch.nn.Module):
else:
for layer in self.linear:
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
- layer.weight.data.normal_(mean=0.0, std=0.01)
- layer.bias.data.zero_()
-
+ w, b = layer.weight.data, layer.bias.data
+ if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
+ normal_(w, mean=0.0, std=0.01)
+ normal_(b, mean=0.0, std=0.005)
+ elif weight_init == 'XavierUniform':
+ xavier_uniform_(w)
+ zeros_(b)
+ elif weight_init == 'XavierNormal':
+ xavier_normal_(w)
+ zeros_(b)
+ elif weight_init == 'KaimingUniform':
+ kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+ zeros_(b)
+ elif weight_init == 'KaimingNormal':
+ kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
+ zeros_(b)
+ else:
+ raise KeyError(f"Key {weight_init} is not defined as initialization!")
self.to(devices.device)
def fix_old_state_dict(self, state_dict):
@@ -105,7 +126,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
self.filename = None
self.name = name
self.layers = {}
@@ -114,13 +135,14 @@ class Hypernetwork:
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
self.activation_func = activation_func
+ self.weight_init = weight_init
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
)
def weights(self):
@@ -144,6 +166,7 @@ class Hypernetwork:
state_dict['layer_structure'] = self.layer_structure
state_dict['activation_func'] = self.activation_func
state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['weight_initialization'] = self.weight_init
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -158,15 +181,21 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ print(self.layer_structure)
self.activation_func = state_dict.get('activation_func', None)
+ print(f"Activation function is {self.activation_func}")
+ self.weight_init = state_dict.get('weight_initialization', 'Normal')
+ print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False)
+ print(f"Dropout usage is set to {self.use_dropout}" )
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
)
self.name = state_dict.get('name', self.name)
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 2b472d87..2c6c0470 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -8,8 +8,9 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
+keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -25,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
activation_func=activation_func,
+ weight_init=weight_init,
add_layer_norm=add_layer_norm,
use_dropout=use_dropout,
)
diff --git a/modules/ui.py b/modules/ui.py
index 03528968..8e343258 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1238,7 +1238,8 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
+ new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
@@ -1342,6 +1343,7 @@ def create_ui(wrap_gradio_gpu_call):
overwrite_old_hypernetwork,
new_hypernetwork_layer_structure,
new_hypernetwork_activation_func,
+ new_hypernetwork_initialization_option,
new_hypernetwork_add_layer_norm,
new_hypernetwork_use_dropout
],
--
cgit v1.2.3
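The if/elif chain above can be read as a dispatch table over torch.nn.init. A condensed sketch (the LayerNorm special case, which always gets Normal init, is omitted, and the nonlinearity argument is fixed to 'relu' for brevity):

```python
import torch
from torch.nn.init import (normal_, xavier_normal_, xavier_uniform_,
                           kaiming_normal_, kaiming_uniform_, zeros_)

# Dispatch-table version of the weight-init chain.
INITS = {
    'Normal':         lambda w, b: (normal_(w, 0.0, 0.01), normal_(b, 0.0, 0.005)),
    'XavierUniform':  lambda w, b: (xavier_uniform_(w), zeros_(b)),
    'XavierNormal':   lambda w, b: (xavier_normal_(w), zeros_(b)),
    'KaimingUniform': lambda w, b: (kaiming_uniform_(w, nonlinearity='relu'), zeros_(b)),
    'KaimingNormal':  lambda w, b: (kaiming_normal_(w, nonlinearity='relu'), zeros_(b)),
}

weight_init = 'XavierNormal'
if weight_init not in INITS:
    raise KeyError(f"Key {weight_init} is not defined as initialization!")
layer = torch.nn.Linear(320, 640)
INITS[weight_init](layer.weight.data, layer.bias.data)
print(layer.weight.data.std().item())
```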
From 7207e3bf49ed000464d288cd67e02f0ba8614dc3 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Tue, 25 Oct 2022 15:24:59 +0900
Subject: remove duplicate keys and lowercase
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index afbcdff8..842b6447 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -32,7 +32,7 @@ class HypernetworkModule(torch.nn.Module):
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
}
- activation_dict.update({cls_name: cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+ activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
super().__init__()
--
cgit v1.2.3
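Lowercasing the reflected class names keeps the hand-written entries ("relu", "tanh", ...) from coexisting with duplicate keys like "ReLU". The reflection trick in isolation:

```python
import inspect
import torch

# Collect every activation class from torch.nn.modules.activation, keyed
# by lowercased class name so "ReLU" lands on the existing "relu" entry
# instead of duplicating it.
activation_dict = {
    name.lower(): cls
    for name, cls in inspect.getmembers(torch.nn.modules.activation)
    if inspect.isclass(cls) and cls.__module__ == 'torch.nn.modules.activation'
}
print(sorted(activation_dict)[:5])
```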
From cbb857b675cf0f169b21515c29da492b513cc8c4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 26 Oct 2022 09:44:02 +0300
Subject: enable creating embedding with --medvram
---
modules/textual_inversion/textual_inversion.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'modules')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 529ed3e2..647ffe3e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -157,6 +157,9 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
+ with devices.autocast():
+ cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
+
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
--
cgit v1.2.3
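Under --medvram/--lowvram the conditioning model stays on the CPU until its first forward pass, so reading its embedding table directly would run on the wrong device; the dummy call forces the migration first. A sketch of the idea (cond_model is hypothetical here):

```python
import torch

# cond_model here is hypothetical; under lowvram/medvram its device hooks
# fire on the first forward pass, so a throwaway call moves it to the GPU
# before its token-embedding weights are read directly.
def warm_up_cond_model(cond_model):
    with torch.autocast("cuda"):
        cond_model([""])  # dummy prompt; output is discarded
```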
From db9ab1a46b5ad4d36ecce76dfee04b7164249829 Mon Sep 17 00:00:00 2001
From: Stephen
Date: Mon, 24 Oct 2022 11:16:07 -0400
Subject: [Bugfix][API] - Fix API response for colab users
---
modules/api/api.py | 17 +++++++++++++----
modules/api/models.py | 10 ++++++----
2 files changed, 19 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index a860a964..ba890243 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -7,6 +7,7 @@ import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
+from typing import List
import json
import io
import base64
@@ -15,12 +16,12 @@ from PIL import Image
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel):
- images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
class ImageToImageResponse(BaseModel):
- images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
@@ -41,6 +42,9 @@ class Api:
# convert base64 to PIL image
return Image.open(io.BytesIO(imgdata))
+ def __processed_info_to_json(self, processed):
+ return json.dumps(processed.info)
+
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -65,7 +69,7 @@ class Api:
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
- return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
+ return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
@@ -111,7 +115,12 @@ class Api:
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
- return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))
+ if (not img2imgreq.include_init_images):
+ # remove img2imgreq.init_images and img2imgreq.mask
+ img2imgreq.init_images = None
+ img2imgreq.mask = None
+
+ return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
def extrasapi(self):
raise NotImplementedError
diff --git a/modules/api/models.py b/modules/api/models.py
index f551fa35..c6d43606 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -31,6 +31,7 @@ class ModelDef(BaseModel):
field_alias: str
field_type: Any
field_value: Any
+ field_exclude: bool = False
class PydanticModelGenerator:
@@ -68,7 +69,7 @@ class PydanticModelGenerator:
field=underscore(k),
field_alias=k,
field_type=field_type_generator(k, v),
- field_value=v.default
+ field_value=v.default,
)
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
@@ -78,7 +79,8 @@ class PydanticModelGenerator:
field=underscore(fields["key"]),
field_alias=fields["key"],
field_type=fields["type"],
- field_value=fields["default"]))
+ field_value=fields["default"],
+ field_exclude=fields["exclude"] if "exclude" in fields else False))
def generate_model(self):
"""
@@ -86,7 +88,7 @@ class PydanticModelGenerator:
from the json and overrides provided at initialization
"""
fields = {
- d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
+ d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
}
DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True
@@ -102,5 +104,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model()
\ No newline at end of file
--
cgit v1.2.3
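The Colab failure comes from built-in generics: list[str] as an annotation needs Python 3.9+ (PEP 585), while Colab at the time shipped an older interpreter, hence typing.List. A minimal model mirroring the fix (pydantic's Json fields parse JSON strings on validation):

```python
from typing import List
from pydantic import BaseModel, Field, Json

# list[str] annotations need Python >= 3.9 (PEP 585); typing.List works
# on older interpreters too.
class TextToImageResponse(BaseModel):
    images: List[str] = Field(default=None, title="Image",
                              description="The generated image in base64 format.")
    parameters: Json
    info: Json

print(TextToImageResponse(images=["iVBORw0..."], parameters='{}', info='{}'))
```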
From b46c64c6e5b40d69521e4d50e2d35f6a35468129 Mon Sep 17 00:00:00 2001
From: Stephen
Date: Mon, 24 Oct 2022 12:18:54 -0400
Subject: clean
---
modules/api/api.py | 4 ----
modules/api/models.py | 2 +-
2 files changed, 1 insertion(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index ba890243..6e9d6097 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -42,9 +42,6 @@ class Api:
# convert base64 to PIL image
return Image.open(io.BytesIO(imgdata))
- def __processed_info_to_json(self, processed):
- return json.dumps(processed.info)
-
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -116,7 +113,6 @@ class Api:
b64images.append(base64.b64encode(buffer.getvalue()))
if (not img2imgreq.include_init_images):
- # remove img2imgreq.init_images and img2imgreq.mask
img2imgreq.init_images = None
img2imgreq.mask = None
diff --git a/modules/api/models.py b/modules/api/models.py
index c6d43606..079e33d9 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -69,7 +69,7 @@ class PydanticModelGenerator:
field=underscore(k),
field_alias=k,
field_type=field_type_generator(k, v),
- field_value=v.default,
+ field_value=v.default
)
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
--
cgit v1.2.3
From 146856f66d7e06a762f5ef5bf61a226057de6757 Mon Sep 17 00:00:00 2001
From: Milly
Date: Tue, 25 Oct 2022 06:21:31 +0900
Subject: images: allow nested bracket in filename pattern
---
modules/images.py | 11 ++++-------
1 file changed, 4 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index 286de2ae..ed448a8a 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -277,7 +277,7 @@ invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
-re_pattern = re.compile(r"([^\[\]]+|\[([^]]+)]|[\[\]]*)")
+re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128
@@ -362,9 +362,9 @@ class FilenameGenerator:
for m in re_pattern.finditer(x):
text, pattern = m.groups()
+ res += text
if pattern is None:
- res += text
continue
pattern_args = []
@@ -385,12 +385,9 @@ class FilenameGenerator:
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- if replacement is None:
- res += f'[{pattern}]'
- else:
+ if replacement is not None:
res += str(replacement)
-
- continue
+ continue
res += f'[{pattern}]'
--
cgit v1.2.3
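The new pattern consumes a lazy text prefix up to the innermost [token], so a stray or nesting "[" falls through as literal text instead of breaking the match. A quick check of that behaviour (the angle-bracket replacement is a stand-in for the real pattern lookup):

```python
import re

# The lazy prefix group passes stray "[" through as literal text while
# still matching the innermost [token].
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")

def apply_pattern(x):
    res = ""
    for m in re_pattern.finditer(x):
        text, pattern = m.groups()
        res += text
        if pattern is not None:
            res += f"<{pattern}>"  # stand-in for the real replacement
    return res

print(apply_pattern("[prompt]-[seed]"))  # <prompt>-<seed>
print(apply_pattern("a[b[seed]c]d"))     # a[b<seed>c]d
```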
From 877d94f97ca5491d8779440769b191e0dcd32c8e Mon Sep 17 00:00:00 2001
From: guaneec
Date: Wed, 26 Oct 2022 14:50:58 +0800
Subject: Back compatibility
---
modules/hypernetworks/hypernetwork.py | 17 ++++++++++-------
1 file changed, 10 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3ce85bb5..dd317085 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -28,7 +28,7 @@ class HypernetworkModule(torch.nn.Module):
"swish": torch.nn.Hardswish,
}
- def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@@ -42,7 +42,7 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# Add an activation func except last layer
- if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 2:
+ if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
@@ -105,7 +105,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
self.filename = None
self.name = name
self.layers = {}
@@ -116,11 +116,12 @@ class Hypernetwork:
self.activation_func = activation_func
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
+ self.activate_output = activate_output
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
)
def weights(self):
@@ -147,6 +148,7 @@ class Hypernetwork:
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+ state_dict['activate_output'] = self.activate_output
torch.save(state_dict, filename)
@@ -161,12 +163,13 @@ class Hypernetwork:
self.activation_func = state_dict.get('activation_func', None)
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.use_dropout = state_dict.get('use_dropout', False)
+ self.activate_output = state_dict.get('activate_output', True)
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
)
self.name = state_dict.get('name', self.name)
--
cgit v1.2.3
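Note the asymmetric defaults: networks trained before this change effectively had an activated output layer, so load() assumes activate_output=True when the key is absent, while newly created networks default to False. In miniature:

```python
# Checkpoints saved before this change have no 'activate_output' key and
# were trained with an activated output layer, so loading must default to
# True; newly created networks default to False.
old_state_dict = {'layer_structure': [1, 2, 1]}  # pre-change checkpoint
print(old_state_dict.get('activate_output', True))  # True on load
```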
From 757264c453eca533ee1c9ea7e9d9b45a009367d7 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Tue, 25 Oct 2022 23:39:21 +0900
Subject: default_time_format if format is blank
---
modules/images.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index ed448a8a..bfc2ba06 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -343,7 +343,7 @@ class FilenameGenerator:
def datetime(self, *args):
time_datetime = datetime.datetime.now()
- time_format = args[0] if len(args) > 0 else self.default_time_format
+ time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
try:
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
except pytz.exceptions.UnknownTimeZoneError as _:
--
cgit v1.2.3
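A pattern like [datetime<>] passes an empty string as the format argument; the extra args[0] != "" check makes that fall back to the default instead of yielding an empty filename part. A sketch (the default format string here is assumed from the webui's convention):

```python
import datetime

# args[0] != "" makes an empty format (e.g. from "[datetime<>]") fall back
# to the default instead of producing an empty filename part.
default_time_format = '%Y%m%d%H%M%S'  # assumed webui default

def format_time(*args):
    time_format = args[0] if len(args) > 0 and args[0] != "" else default_time_format
    return datetime.datetime.now().strftime(time_format)

print(format_time())      # default
print(format_time(""))    # also default after the fix
print(format_time("%Y"))  # explicit format still honored
```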
From 9d82c351ac36d1511f5f65b24443c60250ee3e9e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 26 Oct 2022 09:56:25 +0300
Subject: fix typo in on_save_imaged/on_image_saved; hope no extension is using
it yet
---
modules/script_callbacks.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index dc520abc..6803d57b 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -90,7 +90,7 @@ def on_ui_settings(callback):
add_callback(callbacks_ui_settings, callback)
-def on_save_imaged(callback):
+def on_image_saved(callback):
"""register a function to be called after modules.images.save_image is called.
The callback is called with three arguments:
- p - processing object (or a dummy object with same fields if the image is saved using save button)
--
cgit v1.2.3
From 91bb35b1e6842b30ce7553009c8ecea3643de8d2 Mon Sep 17 00:00:00 2001
From: guaneec
Date: Wed, 26 Oct 2022 15:00:03 +0800
Subject: Merge fix
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index eab8b32f..bd171793 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -190,7 +190,7 @@ class Hypernetwork:
print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False)
print(f"Layer norm is set to {self.add_layer_norm}")
- self.use_dropout = state_dict.get('use_dropout', False
+ self.use_dropout = state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True)
--
cgit v1.2.3
From cb49800c08a9f6619733250401952e5571dc12f8 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Tue, 25 Oct 2022 01:39:59 -0700
Subject: img2img, use smartphone photos' EXIF orientation
---
modules/img2img.py | 8 ++++++++
1 file changed, 8 insertions(+)
(limited to 'modules')
diff --git a/modules/img2img.py b/modules/img2img.py
index 8d9f7cf9..9c0cf23e 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -39,6 +39,8 @@ def process_batch(p, input_dir, output_dir, args):
break
img = Image.open(image)
+ # Use the EXIF orientation of photos taken by smartphones.
+ img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
proc = modules.scripts.scripts_img2img.run(p, *args)
@@ -61,19 +63,25 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
is_batch = mode == 2
if is_inpaint:
+ # Drawn mask
if mask_mode == 0:
image = init_img_with_mask['image']
mask = init_img_with_mask['mask']
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
image = image.convert('RGB')
+ # Uploaded mask
else:
image = init_img_inpaint
mask = init_mask_inpaint
+ # No mask
else:
image = init_img
mask = None
+ # Use the EXIF orientation of photos taken by smartphones.
+ image = ImageOps.exif_transpose(image)
+
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
p = StableDiffusionProcessingImg2Img(
--
cgit v1.2.3
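ImageOps.exif_transpose applies the EXIF Orientation tag and returns an upright copy, so portrait phone photos are no longer processed sideways. Minimal usage (the blank image here just stands in for a real photo):

```python
from PIL import Image, ImageOps

# exif_transpose rotates/flips according to the EXIF Orientation tag;
# with no EXIF data it simply returns an unchanged copy.
img = Image.new("RGB", (640, 480))  # stand-in for a real phone photo
img = ImageOps.exif_transpose(img)
print(img.size)
```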
From a524d137d0a89bb19a6676dc9b8fbb5d1b580678 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Mon, 24 Oct 2022 23:48:05 -0700
Subject: patch bug (SeverianVoid's comment on 5245c7a)
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 842b6447..8113b35b 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -487,7 +487,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None:
shared.state.current_image = image
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
--
cgit v1.2.3
From c2dc9bfa89070b8e1d857f8773a790b752f1b709 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Mon, 24 Oct 2022 23:22:58 -0700
Subject: Implement PR #3189 but for embeddings.
---
modules/textual_inversion/textual_inversion.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 647ffe3e..22c7b54b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -10,7 +10,7 @@ import csv
from PIL import Image, PngImagePlugin
-from modules import shared, devices, sd_hijack, processing, sd_models
+from modules import shared, devices, sd_hijack, processing, sd_models, images
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -247,6 +247,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
last_saved_file = ""
last_saved_image = ""
+ forced_filename = ""
embedding_yet_to_be_embedded = False
ititial_step = embedding.step or 0
@@ -296,8 +297,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
})
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
- last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
-
+ forced_filename = f'{embedding_name}-{embedding.step}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
@@ -353,8 +354,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
- image.save(last_saved_image)
-
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
--
cgit v1.2.3
From 4875a6c217df5cc06ee2bf11fb645b172c7156a8 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Mon, 24 Oct 2022 23:38:07 -0700
Subject: Implement PR #3309 but for embeddings.
---
modules/textual_inversion/textual_inversion.py | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 22c7b54b..4921bd01 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -167,6 +167,8 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
for i in range(num_vectors_per_token):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
@@ -287,7 +289,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
- last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
+ # Before saving, change name to match current checkpoint.
+ embedding.name = f'{embedding_name}-{embedding.step}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True
@@ -374,6 +378,9 @@ Last saved image: {html.escape(last_saved_image)}
embedding.sd_checkpoint = checkpoint.hash
embedding.sd_checkpoint_name = checkpoint.model_name
embedding.cached_checksum = None
+ # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
+ embedding.name = embedding_name
+ filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt')
embedding.save(filename)
return embedding, filename
--
cgit v1.2.3
From f4e14642173a04723200b131deb417c6c79cab17 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Tue, 25 Oct 2022 00:04:25 -0700
Subject: Implement PR #3625 but for embeddings.
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4921bd01..4fcebe74 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -358,7 +358,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
--
cgit v1.2.3
From b6a8bb123bd519736306417399f6441e504f1e8b Mon Sep 17 00:00:00 2001
From: guaneec
Date: Wed, 26 Oct 2022 15:15:19 +0800
Subject: Fix merge
---
modules/hypernetworks/hypernetwork.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index bd171793..2997cead 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -60,7 +60,7 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout except last layer
- if use_dropout and i < len(layer_structure) - 2:
+ if use_dropout and i < len(layer_structure) - 3:
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
@@ -126,7 +126,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False)
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False):
self.filename = None
self.name = name
self.layers = {}
--
cgit v1.2.3
From 1e428238db4e399b7a06ad5251cb16eef23a014d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 26 Oct 2022 11:47:07 +0300
Subject: add override_settings to API as an alternative to #3629
---
modules/processing.py | 25 ++++++++++++++++++++-----
modules/shared.py | 4 ++--
2 files changed, 22 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index c61bbfbd..4efba946 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -77,9 +77,8 @@ def get_correct_sampler(p):
class StableDiffusionProcessing():
"""
The first set of parameters: sd_models -> do_not_reload_embeddings represents the minimum required to create a StableDiffusionProcessing
-
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -109,13 +108,14 @@ class StableDiffusionProcessing():
self.do_not_reload_embeddings = do_not_reload_embeddings
self.paste_to = None
self.color_corrections = None
- self.denoising_strength: float = 0
+ self.denoising_strength: float = denoising_strength
self.sampler_noise_scheduler_override = None
- self.ddim_discretize = opts.ddim_discretize
+ self.ddim_discretize = ddim_discretize or opts.ddim_discretize
self.s_churn = s_churn or opts.s_churn
self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise
+ self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
if not seed_enable_extras:
self.subseed = -1
@@ -129,7 +129,6 @@ class StableDiffusionProcessing():
self.all_seeds = None
self.all_subseeds = None
-
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -351,6 +350,22 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
def process_images(p: StableDiffusionProcessing) -> Processed:
+ stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
+
+ try:
+ for k, v in p.override_settings.items():
+ opts.data[k] = v # we don't call onchange for simplicity; this makes changing the model or hypernetwork impossible
+
+ res = process_images_inner(p)
+
+ finally:
+ for k, v in stored_opts.items():
+ opts.data[k] = v
+
+ return res
+
+
+def process_images_inner(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
if type(p.prompt) == list:
diff --git a/modules/shared.py b/modules/shared.py
index 308fccce..1a9b8289 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -84,7 +84,7 @@ parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load mod
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
cmd_opts = parser.parse_args()
-restricted_opts = [
+restricted_opts = {
"samples_filename_pattern",
"directories_filename_pattern",
"outdir_samples",
@@ -94,7 +94,7 @@ restricted_opts = [
"outdir_grids",
"outdir_txt2img_grids",
"outdir_save",
-]
+}
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
--
cgit v1.2.3
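The save/override/restore pattern in process_images generalizes to a small context manager. A hedged sketch of the same idea, with a plain dict standing in for shared.opts.data:

from contextlib import contextmanager

@contextmanager
def override_settings(opts: dict, overrides: dict, restricted: set):
    # Drop restricted keys first, as the patch does with shared.restricted_opts.
    overrides = {k: v for k, v in overrides.items() if k not in restricted}
    stored = {k: opts[k] for k in overrides}
    try:
        opts.update(overrides)
        yield opts
    finally:
        opts.update(stored)  # restore even if the wrapped call raised

opts = {"CLIP_stop_at_last_layers": 1, "outdir_samples": "/out"}
with override_settings(opts, {"CLIP_stop_at_last_layers": 2, "outdir_samples": "/tmp"}, restricted={"outdir_samples"}):
    assert opts == {"CLIP_stop_at_last_layers": 2, "outdir_samples": "/out"}
assert opts["CLIP_stop_at_last_layers"] == 1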
From 0cd74602531a40f72d1a75b471a8a9166135d333 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 26 Oct 2022 13:12:44 +0300
Subject: add script callback for before image save and change callback for
after image save to use a class with parameters
---
modules/images.py | 42 ++++++++++++++++++++++-----------------
modules/script_callbacks.py | 48 +++++++++++++++++++++++++++++++++++++--------
2 files changed, 64 insertions(+), 26 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index bfc2ba06..7870b5b7 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -451,17 +451,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
"""
namegen = FilenameGenerator(p, seed, prompt)
- if extension == 'png' and opts.enable_pnginfo and info is not None:
- pnginfo = PngImagePlugin.PngInfo()
-
- if existing_info is not None:
- for k, v in existing_info.items():
- pnginfo.add_text(k, str(v))
-
- pnginfo.add_text(pnginfo_section_name, info)
- else:
- pnginfo = None
-
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
@@ -489,19 +478,27 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if add_number:
basecount = get_next_sequence_number(path, basename)
fullfn = None
- fullfn_without_extension = None
for i in range(500):
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
- fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
if not os.path.exists(fullfn):
break
else:
fullfn = os.path.join(path, f"{file_decoration}.{extension}")
- fullfn_without_extension = os.path.join(path, file_decoration)
else:
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
- fullfn_without_extension = os.path.join(path, forced_filename)
+
+ pnginfo = existing_info or {}
+ if info is not None:
+ pnginfo[pnginfo_section_name] = info
+
+ params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
+ script_callbacks.before_image_saved_callback(params)
+
+ image = params.image
+ fullfn = params.filename
+ info = params.pnginfo.get(pnginfo_section_name, None)
+ fullfn_without_extension, extension = os.path.splitext(params.filename)
def exif_bytes():
return piexif.dump({
@@ -510,12 +507,20 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
},
})
- if extension.lower() in ("jpg", "jpeg", "webp"):
+ if extension.lower() == '.png':
+ pnginfo_data = PngImagePlugin.PngInfo()
+ for k, v in params.pnginfo.items():
+ pnginfo_data.add_text(k, str(v))
+
+ image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
+
+ elif extension.lower() in (".jpg", ".jpeg", ".webp"):
image.save(fullfn, quality=opts.jpeg_quality)
+
if opts.enable_pnginfo and info is not None:
piexif.insert(exif_bytes(), fullfn)
else:
- image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
+ image.save(fullfn, quality=opts.jpeg_quality)
target_side_length = 4000
oversize = image.width > target_side_length or image.height > target_side_length
@@ -538,7 +543,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else:
txt_fullfn = None
- script_callbacks.image_saved_callback(image, p, fullfn, txt_fullfn)
+ script_callbacks.image_saved_callback(params)
+
return fullfn, txt_fullfn
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 6803d57b..6ea58d61 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -9,15 +9,34 @@ def report_exception(c, job):
print(traceback.format_exc(), file=sys.stderr)
+class ImageSaveParams:
+ def __init__(self, image, p, filename, pnginfo):
+ self.image = image
+ """the PIL image itself"""
+
+ self.p = p
+ """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
+
+ self.filename = filename
+ """name of file that the image would be saved to"""
+
+ self.pnginfo = pnginfo
+ """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
+
+
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
+callbacks_before_image_saved = []
callbacks_image_saved = []
+
def clear_callbacks():
callbacks_model_loaded.clear()
callbacks_ui_tabs.clear()
+ callbacks_ui_settings.clear()
+ callbacks_before_image_saved.clear()
callbacks_image_saved.clear()
@@ -49,10 +68,18 @@ def ui_settings_callback():
report_exception(c, 'ui_settings_callback')
-def image_saved_callback(image, p, fullfn, txt_fullfn):
+def before_image_saved_callback(params: ImageSaveParams):
for c in callbacks_image_saved:
try:
- c.callback(image, p, fullfn, txt_fullfn)
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'before_image_saved_callback')
+
+
+def image_saved_callback(params: ImageSaveParams):
+ for c in callbacks_image_saved:
+ try:
+ c.callback(params)
except Exception:
report_exception(c, 'image_saved_callback')
@@ -64,7 +91,6 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
-
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
@@ -90,11 +116,17 @@ def on_ui_settings(callback):
add_callback(callbacks_ui_settings, callback)
+def on_before_image_saved(callback):
+ """register a function to be called before an image is saved to a file.
+ The callback is called with one argument:
+ - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
+ """
+ add_callback(callbacks_before_image_saved, callback)
+
+
def on_image_saved(callback):
- """register a function to be called after modules.images.save_image is called.
- The callback is called with three arguments:
- - p - processing object (or a dummy object with same fields if the image is saved using save button)
- - fullfn - image filename
- - txt_fullfn - text file with parameters; may be None
+ """register a function to be called after an image is saved to a file.
+ The callback is called with one argument:
+ - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
add_callback(callbacks_image_saved, callback)
--
cgit v1.2.3
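From an extension's point of view, the new hook means registering a function that mutates the ImageSaveParams fields in place before the file is written. A plausible extension snippet built only on the API shown in this diff (it runs inside the webui, not standalone):

from modules import script_callbacks

def stamp_pnginfo(params: script_callbacks.ImageSaveParams):
    # Changing fields here affects the file that save_image is about to write.
    params.pnginfo["my-extension"] = "processed"
    infotext = params.pnginfo.get("parameters")
    if infotext is not None:
        params.pnginfo["parameters"] = infotext + ", MyExtension: enabled"

script_callbacks.on_before_image_saved(stamp_pnginfo)

Note that as committed, before_image_saved_callback still iterates callbacks_image_saved; the fix lands a few commits later.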
From 7bd8581e461623932ffbd5762ee931ee51f798db Mon Sep 17 00:00:00 2001
From: Sihan Wang <31711261+shwang95@users.noreply.github.com>
Date: Wed, 26 Oct 2022 20:32:55 +0800
Subject: Fix error caused by EXIF transpose when using custom scripts
Some custom scripts read the image directly and do not require selecting an image in the UI; this caused an error.
---
modules/img2img.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/img2img.py b/modules/img2img.py
index 9c0cf23e..86a19f37 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -80,7 +80,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
mask = None
# Use the EXIF orientation of photos taken by smartphones.
- image = ImageOps.exif_transpose(image)
+ if image is not None:
+ image = ImageOps.exif_transpose(image)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
--
cgit v1.2.3
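For reference, ImageOps.exif_transpose rotates or flips an image according to its EXIF Orientation tag, and the guard simply skips it when a custom script supplies no UI image. A standalone equivalent in plain Pillow:

from PIL import ImageOps

def normalize_orientation(image):
    # Custom scripts may pass None because they load images themselves.
    if image is None:
        return None
    return ImageOps.exif_transpose(image)  # honors the EXIF Orientation tag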
From b2e0d8ba789b345145436f6e960a3f0a896a6643 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Wed, 26 Oct 2022 09:54:26 -0300
Subject: Remove folder endpoint
---
modules/api/api.py | 9 ---------
modules/api/models.py | 6 +-----
2 files changed, 1 insertion(+), 14 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index ca289d9f..49c213ea 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -32,7 +32,6 @@ class Api:
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
- self.app.add_api_route("/sdapi/v1/extra-folder-images", self.extras_folder_processing_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -126,14 +125,6 @@ class Api:
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
- def extras_folder_processing_api(self, req:ExtrasFoldersRequest):
- reqDict = setUpscalers(req)
-
- with self.queue_lock:
- result = run_extras(extras_mode=2, image=None, image_folder=None, **reqDict)
-
- return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
-
def pnginfoapi(self):
raise NotImplementedError
diff --git a/modules/api/models.py b/modules/api/models.py
index 00406368..dd122321 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -148,8 +148,4 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: list[str] = Field(title="Images", description="The generated images in base64 format.")
-
-class ExtrasFoldersRequest(ExtrasBaseRequest):
- input_dir: str = Field(title="Input directory", description="Directory path from where to take the images")
- output_dir: str = Field(title="Output directory", description="Directory path to put the processed images into")
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.")
\ No newline at end of file
--
cgit v1.2.3
From 737eb28faca8be2bb996ee0930ec77d1f7ebd939 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 26 Oct 2022 14:45:33 +0100
Subject: typo: cmd_opts.embedding_dir to cmd_opts.embeddings_dir
---
modules/textual_inversion/textual_inversion.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4fcebe74..ff002d3e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -380,7 +380,7 @@ Last saved image: {html.escape(last_saved_image)}
embedding.cached_checksum = None
# Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
embedding.name = embedding_name
- filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt')
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
embedding.save(filename)
return embedding, filename
--
cgit v1.2.3
From fddb4883f4a408b3464076465e1b0949ebe0fc30 Mon Sep 17 00:00:00 2001
From: evshiron
Date: Wed, 26 Oct 2022 22:33:45 +0800
Subject: prototype progress api
---
modules/api/api.py | 89 +++++++++++++++++++++++++++++++++++++++++++++---------
modules/shared.py | 13 ++++++++
2 files changed, 88 insertions(+), 14 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 6e9d6097..c038f674 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,8 +1,11 @@
+import time
+
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
+from modules import devices
import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
@@ -25,6 +28,37 @@ class ImageToImageResponse(BaseModel):
parameters: Json
info: Json
+class ProgressResponse(BaseModel):
+ progress: float
+ eta_relative: float
+ state: Json
+
+# copied from wrap_gradio_gpu_call in webui.py,
+# because the queue lock is acquired in the api handlers
+# and the start time needs to be set;
+# the function has been split into two parts
+
+def before_gpu_call():
+ devices.torch_gc()
+
+ shared.state.sampling_step = 0
+ shared.state.job_count = -1
+ shared.state.job_no = 0
+ shared.state.job_timestamp = shared.state.get_job_timestamp()
+ shared.state.current_latent = None
+ shared.state.current_image = None
+ shared.state.current_image_sampling_step = 0
+ shared.state.skipped = False
+ shared.state.interrupted = False
+ shared.state.textinfo = None
+ shared.state.time_start = time.time()
+
+
+def after_gpu_call():
+ shared.state.job = ""
+ shared.state.job_count = 0
+
+ devices.torch_gc()
class Api:
def __init__(self, app, queue_lock):
@@ -33,6 +67,7 @@ class Api:
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
def __base64_to_image(self, base64_string):
# if has a comma, deal with prefix
@@ -44,12 +79,12 @@ class Api:
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
-
+
if sampler_index is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
-
+ raise HTTPException(status_code=404, detail="Sampler not found")
+
populate = txt2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
+ "sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
"do_not_save_grid": True
@@ -57,9 +92,11 @@ class Api:
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
+ before_gpu_call()
with self.queue_lock:
processed = process_images(p)
-
+ after_gpu_call()
+
b64images = []
for i in processed.images:
buffer = io.BytesIO()
@@ -67,30 +104,30 @@ class Api:
b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
-
-
+
+
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
-
+
if sampler_index is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
+ raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images
if init_images is None:
- raise HTTPException(status_code=404, detail="Init image not found")
+ raise HTTPException(status_code=404, detail="Init image not found")
mask = img2imgreq.mask
if mask:
mask = self.__base64_to_image(mask)
-
+
populate = img2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
+ "sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
- "do_not_save_grid": True,
+ "do_not_save_grid": True,
"mask": mask
}
)
@@ -103,9 +140,11 @@ class Api:
p.init_images = imgs
# Override object param
+ before_gpu_call()
with self.queue_lock:
processed = process_images(p)
-
+ after_gpu_call()
+
b64images = []
for i in processed.images:
buffer = io.BytesIO()
@@ -118,6 +157,28 @@ class Api:
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
+ def progressapi(self):
+ # copy from check_progress_call of ui.py
+
+ if shared.state.job_count == 0:
+ return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())
+
+ # avoid dividing zero
+ progress = 0.01
+
+ if shared.state.job_count > 0:
+ progress += shared.state.job_no / shared.state.job_count
+ if shared.state.sampling_steps > 0:
+ progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+
+ time_since_start = time.time() - shared.state.time_start
+ eta = (time_since_start/progress)
+ eta_relative = eta-time_since_start
+
+ progress = min(progress, 1)
+
+ return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
+
def extrasapi(self):
raise NotImplementedError
diff --git a/modules/shared.py b/modules/shared.py
index 1a9b8289..00f61898 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -146,6 +146,19 @@ class State:
def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
+ def js(self):
+ obj = {
+ "skipped": self.skipped,
+ "interrupted": self.skipped,
+ "job": self.job,
+ "job_count": self.job_count,
+ "job_no": self.job_no,
+ "sampling_step": self.sampling_step,
+ "sampling_steps": self.sampling_steps,
+ }
+
+ return json.dumps(obj)
+
state = State()
--
cgit v1.2.3
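The ETA arithmetic in progressapi deserves a worked example. Using the same formulas, a 0.01 floor to avoid division by zero, then the job fraction plus the within-job sampling fraction:

def estimate(job_no, job_count, sampling_step, sampling_steps, time_since_start):
    progress = 0.01  # floor, so time_since_start / progress is always defined
    if job_count > 0:
        progress += job_no / job_count
        if sampling_steps > 0:
            progress += (1 / job_count) * (sampling_step / sampling_steps)
    eta = time_since_start / progress      # projected total runtime
    eta_relative = eta - time_since_start  # estimated seconds remaining
    return min(progress, 1), eta_relative

# 1 of 2 jobs done, halfway through 20 sampling steps, 30 seconds elapsed:
progress, eta_rel = estimate(1, 2, 10, 20, 30.0)
# progress = 0.01 + 0.5 + 0.25 = 0.76; eta_relative = 30/0.76 - 30, about 9.5 s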
From 0dd8480281ffa3e58439a3ce059c02d9f3baa5c7 Mon Sep 17 00:00:00 2001
From: MMaker
Date: Wed, 26 Oct 2022 11:08:44 -0400
Subject: fix: Correct before image saved callback
---
modules/script_callbacks.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 6ea58d61..cedbe7bd 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -69,7 +69,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
- for c in callbacks_image_saved:
+ for c in callbacks_before_image_saved:
try:
c.callback(params)
except Exception:
--
cgit v1.2.3
From 3de036514138d7cdcba9729c975f1683a8e06b16 Mon Sep 17 00:00:00 2001
From: xmodar
Date: Wed, 26 Oct 2022 23:56:11 +0300
Subject: Add id access to scripts list in the css
---
modules/scripts.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/scripts.py b/modules/scripts.py
index 9323af3e..a7f36012 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -236,7 +236,7 @@ class ScriptRunner:
with gr.Group():
create_script_ui(script, inputs, inputs_alwayson)
- dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
+ dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs[0] = dropdown
--
cgit v1.2.3
From 4a4647e0dfc812783db7fa993d486b031f098ef8 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Thu, 27 Oct 2022 13:36:11 +0800
Subject: create send to buttons in one module
---
modules/generation_parameters_copypaste.py | 86 +++++++-
modules/shared.py | 1 +
modules/ui.py | 344 +++++++++--------------------
3 files changed, 184 insertions(+), 247 deletions(-)
(limited to 'modules')
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index f73647da..2b80737a 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -3,13 +3,16 @@ import re
import gradio as gr
from modules.shared import script_path
from modules import shared
+import tempfile
+from PIL import Image, PngImagePlugin
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())
-
+paste_fields = {}
+bind_list = []
def quote(text):
if ',' not in str(text):
@@ -20,6 +23,81 @@ def quote(text):
text = text.replace('"', '\\"')
return f'"{text}"'
+def image_from_url_text(filedata):
+ if type(filedata) == dict and filedata["is_file"]:
+ filename = filedata["name"]
+ tempdir = os.path.normpath(tempfile.gettempdir())
+ normfn = os.path.normpath(filename)
+ assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
+
+ return Image.open(filename)
+
+ if type(filedata) == list:
+ if len(filedata) == 0:
+ return None
+
+ filedata = filedata[0]
+
+ if filedata.startswith("data:image/png;base64,"):
+ filedata = filedata[len("data:image/png;base64,"):]
+
+ filedata = base64.decodebytes(filedata.encode('utf-8'))
+ image = Image.open(io.BytesIO(filedata))
+ return image
+
+def add_paste_fields(tabname, init_img, fields):
+ paste_fields[tabname] = {"init_img":init_img, "fields": fields}
+
+def create_buttons(tabs_list):
+ buttons = {}
+ for tab in tabs_list:
+ buttons[tab] = gr.Button(f"Send to {tab}")
+ return buttons
+
+# if send_generate_info is a tab name, it means generate_info comes from the params fields of that tab
+def bind_buttons(buttons, send_image, send_generate_info):
+ bind_list.append([buttons, send_image, send_generate_info])
+
+def run_bind():
+ for buttons, send_image, send_generate_info in bind_list:
+ for tab in buttons:
+ button = buttons[tab]
+ if send_image and paste_fields[tab]["init_img"]:
+ if type(send_image) == gr.Gallery:
+ button.click(
+ fn=lambda x: image_from_url_text(x),
+ _js="extract_image_from_gallery",
+ inputs=[send_image],
+ outputs=[paste_fields[tab]["init_img"]],
+ )
+ else:
+ button.click(
+ fn=lambda x:x,
+ inputs=[send_image],
+ outputs=[paste_fields[tab]["init_img"]],
+ )
+
+ if send_generate_info and paste_fields[tab]["fields"] is not None:
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2']
+ if shared.opts.send_seed:
+ paste_field_names += ["Seed"]
+ if send_generate_info in paste_fields:
+ button.click(
+ fn=lambda *x:x,
+ inputs=[field for field,name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
+ outputs=[field for field,name in paste_fields[tab]["fields"] if name in paste_field_names],
+ )
+
+ else:
+ connect_paste(button, [(field, name) for field, name in paste_fields[tab]["fields"] if name in paste_field_names], send_generate_info)
+
+ button.click(
+ fn=None,
+ _js=f"switch_to_{tab}",
+ inputs=None,
+ outputs=None,
+ )
+
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -67,8 +145,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res
-
-def connect_paste(button, paste_fields, input_comp, js=None):
+def connect_paste(button, paste_fields, input_comp):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(script_path, "params.txt")
@@ -106,7 +183,8 @@ def connect_paste(button, paste_fields, input_comp, js=None):
button.click(
fn=paste_func,
- _js=js,
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
)
+
+
diff --git a/modules/shared.py b/modules/shared.py
index f8b13b06..3ade2afa 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -279,6 +279,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
+ "send_seed": OptionInfo(False, "Send seed when sending prompt or image to other interface"),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
diff --git a/modules/ui.py b/modules/ui.py
index 3e5b84d2..ccba14b6 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -35,7 +35,7 @@ if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
import modules.codeformer_model
-import modules.generation_parameters_copypaste
+import modules.generation_parameters_copypaste as parameters_copypaste
import modules.gfpgan_model
import modules.hypernetworks.ui
import modules.ldsr_model
@@ -49,14 +49,11 @@ from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
import modules.textual_inversion.ui
import modules.hypernetworks.ui
+from modules.generation_parameters_copypaste import image_from_url_text
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
-txt2img_paste_fields = []
-img2img_paste_fields = []
-init_img_components = {}
-
if not cmd_opts.share and not cmd_opts.listen:
# fix gradio phoning home
@@ -99,37 +96,11 @@ def plaintext_to_html(text):
text = "" + " \n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "
"
return text
-
-def image_from_url_text(filedata):
- if type(filedata) == dict and filedata["is_file"]:
- filename = filedata["name"]
- tempdir = os.path.normpath(tempfile.gettempdir())
- normfn = os.path.normpath(filename)
- assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
-
- return Image.open(filename)
-
- if type(filedata) == list:
- if len(filedata) == 0:
- return None
-
- filedata = filedata[0]
-
- if filedata.startswith("data:image/png;base64,"):
- filedata = filedata[len("data:image/png;base64,"):]
-
- filedata = base64.decodebytes(filedata.encode('utf-8'))
- image = Image.open(io.BytesIO(filedata))
- return image
-
-
def send_gradio_gallery_to_image(x):
if len(x) == 0:
return None
-
return image_from_url_text(x[0])
-
def save_files(js_data, images, do_make_zip, index):
import csv
filenames = []
@@ -193,7 +164,6 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
-
def save_pil_to_file(pil_image, dir=None):
use_metadata = False
metadata = PngImagePlugin.PngInfo()
@@ -626,6 +596,83 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
)
return refresh_button
+def create_output_panel(tabname, outdir):
+ def open_folder(f):
+ if not os.path.exists(f):
+ print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+ return
+ elif not os.path.isdir(f):
+ print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+ return
+
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(f)
+ if platform.system() == "Windows":
+ os.startfile(path)
+ elif platform.system() == "Darwin":
+ sp.Popen(["open", path])
+ else:
+ sp.Popen(["xdg-open", path])
+
+ with gr.Column(variant='panel'):
+ with gr.Group():
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+
+ generation_info = None
+ with gr.Column():
+ with gr.Row():
+ if tabname != "extras":
+ save = gr.Button('Save')
+
+ buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+ button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
+ open_folder = gr.Button(folder_symbol, elem_id=button_id)
+
+ open_folder.click(
+ fn=lambda: open_folder(opts.outdir_samples or outdir),
+ inputs=[],
+ outputs=[],
+ )
+
+ if tabname != "extras":
+ with gr.Row():
+ do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
+
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
+
+ with gr.Group():
+ html_info = gr.HTML()
+ generation_info = gr.Textbox(visible=False)
+
+ save.click(
+ fn=wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ do_make_zip,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_info,
+ html_info,
+ html_info,
+ ]
+ )
+ else:
+ html_info_x = gr.HTML()
+ html_info = gr.HTML()
+ parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
+
+
def create_ui(wrap_gradio_gpu_call):
import modules.img2img
@@ -676,31 +723,10 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
- with gr.Column(variant='panel'):
-
- with gr.Group():
- txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
- txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
-
- with gr.Column():
- with gr.Row():
- save = gr.Button('Save')
- send_to_img2img = gr.Button('Send to img2img')
- send_to_inpaint = gr.Button('Send to inpaint')
- send_to_extras = gr.Button('Send to extras')
- button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
- open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
-
- with gr.Row():
- do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
-
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
-
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+
+ txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -755,24 +781,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=lambda x: gr_show(x),
inputs=[enable_hr],
outputs=[hr_options],
- )
-
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
- inputs=[
- generation_info,
- txt2img_gallery,
- do_make_zip,
- html_info,
- ],
- outputs=[
- download_files,
- html_info,
- html_info,
- html_info,
- ]
- )
+ )
roll.click(
fn=roll_artist,
@@ -785,8 +794,7 @@ def create_ui(wrap_gradio_gpu_call):
]
)
- global txt2img_paste_fields
- txt2img_paste_fields = [
+ parameters_copypaste.add_paste_fields("txt2img", None, [
(txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"),
(steps, "Steps"),
@@ -807,7 +815,7 @@ def create_ui(wrap_gradio_gpu_call):
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
*modules.scripts.scripts_txt2img.infotext_fields
- ]
+ ])
txt2img_preview_params = [
txt2img_prompt,
@@ -894,30 +902,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
- with gr.Column(variant='panel'):
-
- with gr.Group():
- img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
- img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
-
- with gr.Column():
- with gr.Row():
- save = gr.Button('Save')
- img2img_send_to_img2img = gr.Button('Send to img2img')
- img2img_send_to_inpaint = gr.Button('Send to inpaint')
- img2img_send_to_extras = gr.Button('Send to extras')
- button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
- open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
-
- with gr.Row():
- do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
-
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
-
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -1004,24 +989,8 @@ def create_ui(wrap_gradio_gpu_call):
fn=interrogate_deepbooru,
inputs=[init_img],
outputs=[img2img_prompt],
- )
-
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
- inputs=[
- generation_info,
- img2img_gallery,
- do_make_zip,
- html_info,
- ],
- outputs=[
- download_files,
- html_info,
- html_info,
- html_info,
- ]
)
+
roll.click(
fn=roll_artist,
@@ -1056,7 +1025,8 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[prompt, negative_prompt, style1, style2],
)
- global img2img_paste_fields
+ token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+
img2img_paste_fields = [
(img2img_prompt, "Prompt"),
(img2img_negative_prompt, "Negative prompt"),
@@ -1075,7 +1045,9 @@ def create_ui(wrap_gradio_gpu_call):
(denoising_strength, "Denoising strength"),
*modules.scripts.scripts_img2img.infotext_fields
]
- token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+ parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
+ parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
+
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
@@ -1122,15 +1094,8 @@ def create_ui(wrap_gradio_gpu_call):
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
- with gr.Column(variant='panel'):
- result_images = gr.Gallery(label="Result", show_label=False)
- html_info_x = gr.HTML()
- html_info = gr.HTML()
- extras_send_to_img2img = gr.Button('Send to img2img')
- extras_send_to_inpaint = gr.Button('Send to inpaint')
- button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
- open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
+ result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
submit.click(
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
@@ -1160,23 +1125,8 @@ def create_ui(wrap_gradio_gpu_call):
html_info,
]
)
+ parameters_copypaste.add_paste_fields("extras", extras_image, None)
- extras_send_to_img2img.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_img2img",
- inputs=[result_images],
- outputs=[init_img],
- )
-
- extras_send_to_inpaint.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_inpaint",
- inputs=[result_images],
- outputs=[init_img_with_mask],
- )
-
- global init_img_components
- init_img_components = {"img2img":init_img, "inpaint":init_img_with_mask, "extras":extras_image}
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
with gr.Row().style(equal_height=False):
@@ -1187,11 +1137,10 @@ def create_ui(wrap_gradio_gpu_call):
html = gr.HTML()
generation_info = gr.Textbox(visible=False)
html2 = gr.HTML()
-
with gr.Row():
- pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
- pnginfo_send_to_img2img = gr.Button('Send to img2img')
-
+ buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
+ parameters_copypaste.bind_buttons(buttons, image, generation_info)
+
image.change(
fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image],
@@ -1475,28 +1424,6 @@ def create_ui(wrap_gradio_gpu_call):
script_callbacks.ui_settings_callback()
opts.reorder()
- def open_folder(f):
- if not os.path.exists(f):
- print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
- return
- elif not os.path.isdir(f):
- print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
- return
-
- if not shared.cmd_opts.hide_ui_dir_config:
- path = os.path.normpath(f)
- if platform.system() == "Windows":
- os.startfile(path)
- elif platform.system() == "Darwin":
- sp.Popen(["open", path])
- else:
- sp.Popen(["xdg-open", path])
-
def run_settings(*args):
changed = 0
@@ -1641,6 +1568,8 @@ Requested path was: {f}
if column is not None:
column.__exit__()
+ parameters_copypaste.run_bind()
+
interfaces = [
(txt2img_interface, "txt2img", "txt2img"),
(img2img_interface, "img2img", "img2img"),
@@ -1731,85 +1660,14 @@ Requested path was: {f}
component_dict['sd_model_checkpoint'],
]
)
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
- txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
- img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
- send_to_img2img.click(
- fn=lambda img, *args: (image_from_url_text(img),*args),
- _js="(gallery, ...args) => [extract_image_from_gallery_img2img(gallery), ...args]",
- inputs=[txt2img_gallery] + txt2img_fields,
- outputs=[init_img] + img2img_fields,
- )
-
- send_to_inpaint.click(
- fn=lambda x, *args: (image_from_url_text(x), *args),
- _js="(gallery, ...args) => [extract_image_from_gallery_inpaint(gallery), ...args]",
- inputs=[txt2img_gallery] + txt2img_fields,
- outputs=[init_img_with_mask] + img2img_fields,
- )
-
- img2img_send_to_img2img.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_img2img",
- inputs=[img2img_gallery],
- outputs=[init_img],
- )
-
- img2img_send_to_inpaint.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_inpaint",
- inputs=[img2img_gallery],
- outputs=[init_img_with_mask],
- )
-
- send_to_extras.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_extras",
- inputs=[txt2img_gallery],
- outputs=[extras_image],
- )
-
- open_txt2img_folder.click(
- fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
- inputs=[],
- outputs=[],
- )
-
- open_img2img_folder.click(
- fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
- inputs=[],
- outputs=[],
- )
-
- open_extras_folder.click(
- fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
- inputs=[],
- outputs=[],
- )
-
- img2img_send_to_extras.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_extras",
- inputs=[img2img_gallery],
- outputs=[extras_image],
- )
+
settings_map = {
'sd_hypernetwork': 'Hypernet',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
-
- settings_paste_fields = [
- (component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None)))
- for k, v in settings_map.items()
- ]
-
- modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt)
- modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt)
-
- modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img')
- modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img')
+
ui_config_file = cmd_opts.ui_config_file
ui_settings = {}
--
cgit v1.2.3
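The two-phase design is the important part of this refactor: tabs register their paste targets while the UI is being built, and run_bind wires the buttons only after every tab exists. A stripped-down, runnable model of that flow, with strings standing in for gradio components:

paste_fields = {}  # tabname -> {"init_img": component, "fields": [(component, name), ...]}
bind_list = []     # deferred (buttons, send_image, send_generate_info) triples

def add_paste_fields(tabname, init_img, fields):
    paste_fields[tabname] = {"init_img": init_img, "fields": fields}

def bind_buttons(buttons, send_image, send_generate_info):
    bind_list.append((buttons, send_image, send_generate_info))  # wired later, not now

def run_bind():
    for buttons, send_image, send_generate_info in bind_list:
        for tab, button in buttons.items():
            # The real module calls button.click(...) here; we only record the wiring.
            print(f"{button}: sends image={send_image!r}, params from {send_generate_info!r}, to {tab!r}")

add_paste_fields("img2img", "img2img_init_img", [("img2img_prompt", "Prompt")])
add_paste_fields("txt2img", None, [("txt2img_prompt", "Prompt")])
bind_buttons({"img2img": "Send to img2img"}, "txt2img_gallery", "txt2img")
run_bind()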
From 85fcccc105aa50f1d78de559233eaa9f384608b5 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Wed, 26 Oct 2022 22:24:33 +0900
Subject: Squashed commit of fixing dropout silently
fix dropouts for future hypernetworks
add kwargs for Hypernetwork class
hypernet UI for gradio input
add recommended options
remove as options
revert adding options in ui
---
modules/hypernetworks/hypernetwork.py | 25 +++++++++++++++++--------
modules/ui.py | 4 ++--
2 files changed, 19 insertions(+), 10 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 2997cead..dd921153 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -34,7 +34,8 @@ class HypernetworkModule(torch.nn.Module):
}
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
- def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False):
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+ add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@@ -60,7 +61,7 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout except last layer
- if use_dropout and i < len(layer_structure) - 3:
+ if 'last_layer_dropout' in kwargs and kwargs['last_layer_dropout'] and use_dropout and i < len(layer_structure) - 2:
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
@@ -126,7 +127,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
self.filename = None
self.name = name
self.layers = {}
@@ -139,11 +140,14 @@ class Hypernetwork:
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
self.activate_output = activate_output
+ self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
def weights(self):
@@ -172,7 +176,8 @@ class Hypernetwork:
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
state_dict['activate_output'] = self.activate_output
-
+ state_dict['last_layer_dropout'] = self.last_layer_dropout
+
torch.save(state_dict, filename)
def load(self, filename):
@@ -193,12 +198,16 @@ class Hypernetwork:
self.use_dropout = state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True)
+ print(f"Activate last layer is set to {self.activate_output}")
+ self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
- HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
self.name = state_dict.get('name', self.name)
diff --git a/modules/ui.py b/modules/ui.py
index 0a63e357..55cbe859 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1238,8 +1238,8 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
- new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
+ new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Normal is default, for experiments, relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
--
cgit v1.2.3
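One subtlety in this squashed commit: a freshly constructed Hypernetwork defaults last_layer_dropout to True, while load() defaults the same flag to False for files saved before the key existed. The dict idiom behind that, checked standalone:

state_dict = {"use_dropout": True}  # an old file that never saved the new key
last_layer_dropout = state_dict.get('last_layer_dropout', False)
assert last_layer_dropout is False  # old checkpoints keep the pre-patch behavior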
From cc56df996e95c2c82295ab7b9928da2544791220 Mon Sep 17 00:00:00 2001
From: guaneec
Date: Wed, 26 Oct 2022 23:51:51 +0800
Subject: Fix dropout logic
---
modules/hypernetworks/hypernetwork.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index dd921153..b17598fe 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -35,7 +35,7 @@ class HypernetworkModule(torch.nn.Module):
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
- add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
+ add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@@ -61,7 +61,7 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout except last layer
- if 'last_layer_dropout' in kwargs and kwargs['last_layer_dropout'] and use_dropout and i < len(layer_structure) - 2:
+ if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
--
cgit v1.2.3
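The fixed condition is easiest to read by enumerating it. For layer_structure = [1, 2, 2, 1] the index i runs over the gaps between layers; a standalone check of the exact expression from the patch:

layer_structure = [1, 2, 2, 1]

def gets_dropout(i, use_dropout=True, last_layer_dropout=True):
    # Never add dropout right before the output layer; the penultimate gap
    # gets it only when last_layer_dropout is enabled.
    return use_dropout and (i < len(layer_structure) - 3
                            or last_layer_dropout and i < len(layer_structure) - 2)

for i in range(len(layer_structure) - 1):
    print(i, gets_dropout(i, last_layer_dropout=True), gets_dropout(i, last_layer_dropout=False))
# i=0: True  True
# i=1: True  False
# i=2: False False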
From 029d7c75436558f1e884bb127caed73caaecb83a Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Thu, 27 Oct 2022 14:44:53 +0900
Subject: Revert unresolved changes in Bias initialization
it should be zeros_ or properly parameterized in the future.
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b17598fe..25427a37 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -75,7 +75,7 @@ class HypernetworkModule(torch.nn.Module):
w, b = layer.weight.data, layer.bias.data
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
normal_(w, mean=0.0, std=0.01)
- normal_(b, mean=0.0, std=0.005)
+ normal_(b, mean=0.0, std=0)
elif weight_init == 'XavierUniform':
xavier_uniform_(w)
zeros_(b)
--
cgit v1.2.3
From 462e6ba6675bd14c0f82e465423a0eedfff82372 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Thu, 27 Oct 2022 15:40:24 +0900
Subject: Disable unavailable or duplicate options
---
modules/hypernetworks/ui.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 2c6c0470..c2d4b51c 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -8,7 +8,8 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
-keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
+not_available = ["hardswish", "multiheadattention"]
+keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
--
cgit v1.2.3
From e0cbf53f451f45ea73dafab654eb6596cbd67ec2 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Thu, 27 Oct 2022 18:00:51 +0800
Subject: create send-to buttons by extensions
---
modules/ui.py | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index ccba14b6..922a2163 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1568,7 +1568,8 @@ def create_ui(wrap_gradio_gpu_call):
if column is not None:
column.__exit__()
- parameters_copypaste.run_bind()
+
+
interfaces = [
(txt2img_interface, "txt2img", "txt2img"),
@@ -1581,7 +1582,7 @@ def create_ui(wrap_gradio_gpu_call):
interfaces += script_callbacks.ui_tabs_callback()
- interfaces += [(settings_interface, "Settings", "settings")]
+ interfaces += [(settings_interface, "Settings", "settings")]
css = ""
@@ -1667,7 +1668,8 @@ def create_ui(wrap_gradio_gpu_call):
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
-
+
+ parameters_copypaste.run_bind()
ui_config_file = cmd_opts.ui_config_file
ui_settings = {}
--
cgit v1.2.3
From c4b5ca5778340b21288d84dfb8fe1d5773c886a8 Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Thu, 27 Oct 2022 22:00:28 +0900
Subject: Truncate too long filename
---
modules/images.py | 16 ++++++++++++----
1 file changed, 12 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index 7870b5b7..42363ed3 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -416,6 +416,14 @@ def get_next_sequence_number(path, basename):
return result + 1
+def truncate_fullpath(full_path, encoding='utf-8'):
+ dir_name, full_name = os.path.split(full_path)
+ file_name, file_ext = os.path.splitext(full_name)
+ max_length = os.statvfs(dir_name).f_namemax
+ file_name_truncated = file_name.encode(encoding)[:max_length - len(file_ext)].decode(encoding, 'ignore')
+ return os.path.join(dir_name , file_name_truncated + file_ext)
+
+
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
"""Save an image.
@@ -456,7 +464,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if save_to_dirs:
dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
- path = os.path.join(path, dirname)
+ path = truncate_fullpath(os.path.join(path, dirname))
os.makedirs(path, exist_ok=True)
@@ -480,13 +488,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
fullfn = None
for i in range(500):
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
- fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
+ fullfn = truncate_fullpath(os.path.join(path, f"{fn}{file_decoration}.{extension}"))
if not os.path.exists(fullfn):
break
else:
- fullfn = os.path.join(path, f"{file_decoration}.{extension}")
+ fullfn = truncate_fullpath(os.path.join(path, f"{file_decoration}.{extension}"))
else:
- fullfn = os.path.join(path, f"{forced_filename}.{extension}")
+ fullfn = truncate_fullpath(os.path.join(path, f"{forced_filename}.{extension}"))
pnginfo = existing_info or {}
if info is not None:
--
cgit v1.2.3
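
The helper trims the encoded byte length of the filename stem so that stem plus extension fits the name limit reported by os.statvfs, which makes it POSIX-only. The same function with a quick usage line, for reference:

    # Sketch of the byte-aware truncation above (POSIX-only: os.statvfs has
    # no Windows equivalent). Truncating the encoded bytes and decoding with
    # errors='ignore' avoids cutting a multi-byte UTF-8 character in half.
    import os

    def truncate_fullpath(full_path, encoding='utf-8'):
        dir_name, full_name = os.path.split(full_path)
        file_name, file_ext = os.path.splitext(full_name)
        max_length = os.statvfs(dir_name).f_namemax  # e.g. 255 on ext4
        file_name_truncated = file_name.encode(encoding)[:max_length - len(file_ext)].decode(encoding, 'ignore')
        return os.path.join(dir_name, file_name_truncated + file_ext)

    print(truncate_fullpath('/tmp/' + 'x' * 300 + '.png'))  # stem clipped to fit
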
From 0995e879cea8ce871489ea8e393bb0eba6edc09c Mon Sep 17 00:00:00 2001
From: Florian Horn
Date: Thu, 27 Oct 2022 16:20:01 +0200
Subject: added save button and shortcut (s) to Modal View
---
modules/ui.py | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 0a63e357..1332e265 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -630,7 +630,7 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
-
+
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
@@ -683,7 +683,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
with gr.Row():
- save = gr.Button('Save')
+ saveButtonId = 'save_txt2img'
+ save = gr.Button('Save', elem_id=saveButtonId)
send_to_img2img = gr.Button('Send to img2img')
send_to_inpaint = gr.Button('Send to inpaint')
send_to_extras = gr.Button('Send to extras')
@@ -901,7 +902,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
with gr.Row():
- save = gr.Button('Save')
+ saveButtonId = 'save_img2img'
+ save = gr.Button('Save', elem_id=saveButtonId)
img2img_send_to_img2img = gr.Button('Send to img2img')
img2img_send_to_inpaint = gr.Button('Send to inpaint')
img2img_send_to_extras = gr.Button('Send to extras')
--
cgit v1.2.3
From 268159cfe3231743c554a1a9bf15d090c758f920 Mon Sep 17 00:00:00 2001
From: Florian Horn
Date: Thu, 27 Oct 2022 16:32:10 +0200
Subject: fixed indentation
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 1332e265..d49b10b2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -630,7 +630,7 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
-
+
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
--
cgit v1.2.3
From a0a7024c679056dd66beb1832e52041b10143130 Mon Sep 17 00:00:00 2001
From: FlameLaw <116745066+FlameLaw@users.noreply.github.com>
Date: Fri, 28 Oct 2022 02:13:48 +0900
Subject: Fix random dataset shuffle on TI
---
modules/textual_inversion/dataset.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 5b1c5002..8bb00d27 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -86,12 +86,12 @@ class PersonalizedBase(Dataset):
assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
- self.initial_indexes = np.arange(len(self.dataset))
+ self.dataset_length = len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
+ self.indexes = np.random.permutation(self.dataset_length)
def create_text(self, filename_text):
text = random.choice(self.lines)
--
cgit v1.2.3
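
Both the old and the new shuffle produce a permutation of range(len(dataset)); the new code simply asks numpy for it directly instead of indexing a cached index array with a torch permutation. A quick illustrative check:

    # Quick check that the two shuffle forms produce the same kind of result:
    # a permutation of range(n). The new code builds it directly with numpy.
    import numpy as np
    import torch

    n = 10
    old_style = np.arange(n)[torch.randperm(n).numpy()]
    new_style = np.random.permutation(n)
    assert sorted(old_style) == sorted(new_style) == list(range(n))
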
From 26a3fd2fe9314330336fb0e28d1e9ca7d2abe10e Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Thu, 27 Oct 2022 11:27:59 -0700
Subject: Highres fix works with unmasked latent. Also refactor the mask
creation to make it more accessible.
---
modules/processing.py | 134 ++++++++++++++++++++++++++++----------------------
1 file changed, 76 insertions(+), 58 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index f72185ac..548eec29 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -129,6 +129,73 @@ class StableDiffusionProcessing():
self.all_seeds = None
self.all_subseeds = None
+ def txt2img_image_conditioning(self, x, width=None, height=None):
+ if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ return torch.zeros(
+ x.shape[0], 5, 1, 1,
+ dtype=x.dtype,
+ device=x.device
+ )
+
+ height = height or self.height
+ width = width or self.width
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
+ if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ return torch.zeros(
+ latent_image.shape[0], 5, 1, 1,
+ dtype=latent_image.dtype,
+ device=latent_image.device
+ )
+
+ # Handle the different mask inputs
+ if image_mask is not None:
+ if torch.is_tensor(image_mask):
+ conditioning_mask = image_mask
+ else:
+ conditioning_mask = np.array(image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
+ conditioning_mask = conditioning_mask.to(source_image.device)
+ conditioning_image = torch.lerp(
+ source_image,
+ source_image * (1.0 - conditioning_mask),
+ getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
+ )
+
+ # Encode the new masked image using first stage of network.
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
+
+ return image_conditioning
+
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -571,37 +638,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def create_dummy_mask(self, x, width=None, height=None):
- if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- height = height or self.height
- width = width or self.width
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- else:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
-
- return image_conditioning
-
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
@@ -638,7 +684,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
+ image_conditioning = self.img2img_image_conditioning(
+ decoded_samples,
+ samples,
+ decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
+ )
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples
@@ -770,40 +821,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- if self.image_mask is not None:
- conditioning_mask = np.array(self.image_mask.convert("L"))
- conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
- conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
-
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
- else:
- conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
-
- # Create another latent image, this time with a masked version of the original input.
- conditioning_mask = conditioning_mask.to(image.device)
-
- # Smoothly interpolate between the masked and unmasked latent conditioning image.
- conditioning_image = torch.lerp(
- image,
- image * (1.0 - conditioning_mask),
- getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
- )
-
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
-
- # Create the concatenated conditioning tensor to be fed to `c_concat`
- conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
- conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
- self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
- self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
- else:
- self.image_conditioning = torch.zeros(
- self.init_latent.shape[0], 5, 1, 1,
- dtype=self.init_latent.dtype,
- device=self.init_latent.device
- )
+ self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
--
cgit v1.2.3
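
The essential plumbing of img2img_image_conditioning is: binarize the mask, blend the source image toward its masked version by inpainting_mask_weight, encode the result, then concatenate the downsampled mask with the encoded latent along the channel axis to form c_concat. A shape-level sketch with a stub standing in for the SD first-stage encoder:

    # Shape-level sketch of the conditioning built above. The real code
    # encodes with the SD first-stage VAE; here a stub 8x downsample stands
    # in for it, so only the tensor plumbing is illustrated.
    import torch
    import torch.nn.functional as F

    def fake_first_stage_encode(img):          # stub: 3->4 channels, /8 spatial
        return torch.zeros(img.shape[0], 4, img.shape[2] // 8, img.shape[3] // 8)

    source_image = torch.rand(1, 3, 512, 512)  # img2img input
    latent_image = torch.rand(1, 4, 64, 64)    # its latent
    mask = torch.rand(1, 1, 512, 512)
    mask_weight = 1.0                           # opts.inpainting_mask_weight

    conditioning_mask = torch.round(mask)       # inpainting model wants a 0/1 mask
    masked = torch.lerp(source_image, source_image * (1.0 - conditioning_mask), mask_weight)
    cond_image = fake_first_stage_encode(masked)

    small_mask = F.interpolate(conditioning_mask, size=latent_image.shape[-2:])
    image_conditioning = torch.cat([small_mask, cond_image], dim=1)
    print(image_conditioning.shape)             # torch.Size([1, 5, 64, 64]) -> c_concat
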
From a38496c1deef12f56f74f8abce2034bef8bdaccb Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Thu, 27 Oct 2022 11:31:31 -0700
Subject: Moved mask weight config to SD section
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index d47378e8..9c2fa0d4 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -267,6 +267,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
+ "inpainting_mask_weight": OptionInfo(1.0, "Strength of img2img conditioning mask for inpainting models.", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
@@ -320,7 +321,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
- "inpainting_mask_weight": OptionInfo(1.0, "Blend betweeen an unmasked and masked conditioning image for inpainting models.", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
}))
--
cgit v1.2.3
From b68c7c437eda2840a304539dd2acd0b0894e920c Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Thu, 27 Oct 2022 11:45:35 -0700
Subject: Updated name and hover text.
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 9c2fa0d4..7c428d90 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -267,7 +267,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
- "inpainting_mask_weight": OptionInfo(1.0, "Strength of img2img conditioning mask for inpainting models.", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
--
cgit v1.2.3
From bdc90837987ed8919dd611fd01553b0c170ded5c Mon Sep 17 00:00:00 2001
From: Roy Shilkrot
Date: Thu, 27 Oct 2022 15:20:15 -0400
Subject: Add a barebones interrogate API
---
modules/api/api.py | 25 ++++++++++++++++++++++++-
modules/api/models.py | 13 ++++++++++++-
2 files changed, 36 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 6e9d6097..eabdb7b8 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,4 +1,4 @@
-from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
+from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI, InterrogateAPI
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
@@ -25,6 +25,11 @@ class ImageToImageResponse(BaseModel):
parameters: Json
info: Json
+class InterrogateResponse(BaseModel):
+ caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
+ parameters: Json
+ info: Json
+
class Api:
def __init__(self, app, queue_lock):
@@ -33,6 +38,7 @@ class Api:
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
def __base64_to_image(self, base64_string):
# if has a comma, deal with prefix
@@ -118,6 +124,23 @@ class Api:
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
+ def interrogateapi(self, interrogatereq: InterrogateAPI):
+ image_b64 = interrogatereq.image
+ if image_b64 is None:
+ raise HTTPException(status_code=404, detail="Image not found")
+
+ populate = interrogatereq.copy(update={ # Override __init__ params
+ }
+ )
+
+ img = self.__base64_to_image(image_b64)
+
+ # Override object param
+ with self.queue_lock:
+ processed = shared.interrogator.interrogate(img)
+
+ return InterrogateResponse(caption=processed, parameters=json.dumps(vars(interrogatereq)), info=None)
+
def extrasapi(self):
raise NotImplementedError
diff --git a/modules/api/models.py b/modules/api/models.py
index 079e33d9..8be64749 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -63,7 +63,12 @@ class PydanticModelGenerator:
self._model_name = model_name
- self._class_data = merge_class_params(class_instance)
+
+ if class_instance is not None:
+ self._class_data = merge_class_params(class_instance)
+ else:
+ self._class_data = {}
+
self._model_def = [
ModelDef(
field=underscore(k),
@@ -105,4 +110,10 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
+).generate_model()
+
+InterrogateAPI = PydanticModelGenerator(
+ "Interrogate",
+ None,
+ [{"key": "image", "type": str, "default": None}]
).generate_model()
\ No newline at end of file
--
cgit v1.2.3
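
With the route registered, a client can request a caption by POSTing a base64-encoded image. A hedged usage sketch; the URL assumes a local webui started with --api on the default port, and the field names follow the InterrogateAPI model in the diff:

    # Hedged client-side sketch for the new /sdapi/v1/interrogate route.
    # Assumes the webui runs locally with --api on the default port 7860.
    import base64
    import requests

    with open("photo.png", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    resp = requests.post(
        "http://127.0.0.1:7860/sdapi/v1/interrogate",
        json={"image": image_b64},   # only field defined by InterrogateAPI above
    )
    print(resp.json()["caption"])    # InterrogateResponse.caption
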
From b50ff4f4e4d4d6bf31e222832d3fe4cfde4703c9 Mon Sep 17 00:00:00 2001
From: Josh Watzman
Date: Thu, 27 Oct 2022 21:59:16 +0100
Subject: Reduce peak memory usage when changing models
A few tweaks to reduce peak memory usage, the biggest being that if we
aren't using the checkpoint cache, we shouldn't duplicate the model
state dict just to immediately throw it away.
On my machine with 16GB of RAM, this change means I can typically change
models, whereas before it would usually OOM.
---
modules/sd_models.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index e697bb72..203e99a8 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -170,7 +170,9 @@ def load_model_weights(model, checkpoint_info):
print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd)
- missing, extra = model.load_state_dict(sd, strict=False)
+ del pl_sd
+ model.load_state_dict(sd, strict=False)
+ del sd
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
@@ -194,9 +196,10 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.to(devices.dtype_vae)
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
- checkpoints_loaded.popitem(last=False) # LRU
+ if shared.opts.sd_checkpoint_cache > 0:
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ checkpoints_loaded.popitem(last=False) # LRU
else:
print(f"Loading weights [{sd_model_hash}] from cache")
checkpoints_loaded.move_to_end(checkpoint_info)
--
cgit v1.2.3
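
The pattern is to drop each reference as soon as the next object has been derived from it, so only one full copy of the weights is alive at the peak, and to copy the state dict only when caching is actually enabled. A schematic sketch of that shape (plain torch.load, not the webui's full loader):

    # Schematic of the peak-memory pattern above, with a plain checkpoint
    # load. Deleting each reference once the next stage exists keeps only
    # one copy of the (large) weights alive at a time.
    import torch

    def load_weights(model, path, cache=None):
        pl_sd = torch.load(path, map_location="cpu")
        sd = pl_sd.get("state_dict", pl_sd)   # unwrap Lightning-style checkpoints
        del pl_sd                             # drop the wrapper before loading
        model.load_state_dict(sd, strict=False)
        del sd                                # drop the CPU copy once weights are in
        if cache is not None:                 # only duplicate when caching is on
            cache[path] = model.state_dict().copy()
        return model
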
From 5d5dc64064d8ca399a76fe44dbb62bdef6c4b7c4 Mon Sep 17 00:00:00 2001
From: Antonio
Date: Fri, 28 Oct 2022 05:49:39 +0200
Subject: Natural sorting for dropdown checkpoint list
Example:
Before After
11.ckpt 11.ckpt
ab.ckpt ab.ckpt
ade_pablo_step_1000.ckpt ade_pablo_step_500.ckpt
ade_pablo_step_500.ckpt ade_pablo_step_1000.ckpt
ade_step_1000.ckpt ade_step_500.ckpt
ade_step_1500.ckpt ade_step_1000.ckpt
ade_step_2000.ckpt ade_step_1500.ckpt
ade_step_2500.ckpt ade_step_2000.ckpt
ade_step_3000.ckpt ade_step_2500.ckpt
ade_step_500.ckpt ade_step_3000.ckpt
atp_step_5500.ckpt atp_step_5500.ckpt
model1.ckpt model1.ckpt
model10.ckpt model10.ckpt
model1000.ckpt model33.ckpt
model33.ckpt model50.ckpt
model400.ckpt model400.ckpt
model50.ckpt model1000.ckpt
moo44.ckpt moo44.ckpt
v1-4-pruned-emaonly.ckpt v1-4-pruned-emaonly.ckpt
v1-5-pruned-emaonly.ckpt v1-5-pruned-emaonly.ckpt
v1-5-pruned.ckpt v1-5-pruned.ckpt
v1-5-vae.ckpt v1-5-vae.ckpt
---
modules/sd_models.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index e697bb72..64d5ee0d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -3,6 +3,7 @@ import os.path
import sys
from collections import namedtuple
import torch
+import re
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
@@ -35,8 +36,10 @@ def setup_model():
list_models()
-def checkpoint_tiles():
- return sorted([x.title for x in checkpoints_list.values()])
+def checkpoint_tiles():
+ convert = lambda name: int(name) if name.isdigit() else name.lower()
+ alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
+ return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
def list_models():
--
cgit v1.2.3
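
The sort key splits digit runs out of each title and compares them as integers, so step_500 orders before step_1000. A quick demonstration of the same key function:

    # Quick demo of the natural-sort key used above: digit runs compare as
    # integers, everything else compares case-insensitively as text.
    import re

    def alphanumeric_key(key):
        convert = lambda name: int(name) if name.isdigit() else name.lower()
        return [convert(c) for c in re.split('([0-9]+)', key)]

    names = ['model1000.ckpt', 'model33.ckpt', 'model50.ckpt', 'model400.ckpt']
    print(sorted(names))                          # lexicographic: 1000 first
    print(sorted(names, key=alphanumeric_key))    # natural: 33, 50, 400, 1000
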
From b2a8b263b2f09bd772f75502c5a83656580f34ec Mon Sep 17 00:00:00 2001
From: benkyoujouzu
Date: Thu, 27 Oct 2022 13:00:47 +0800
Subject: Add missing support for linear activation in hypernetwork
---
modules/hypernetworks/hypernetwork.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 8113b35b..87cf3cf3 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -25,6 +25,7 @@ from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
+ "linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
--
cgit v1.2.3
From 9e465c8aa5616df4c6723bee007ffd3910404f12 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Thu, 27 Oct 2022 23:03:34 -0700
Subject: Add strength to textinfo.
---
modules/processing.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 4efba946..93066522 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -329,6 +329,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernetwork strength": (None if shared.loaded_hypernetwork is None else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
--
cgit v1.2.3
From d4a069a23cb19104b4e58a33d0d1670fadaefb7a Mon Sep 17 00:00:00 2001
From: timntorres
Date: Thu, 27 Oct 2022 23:16:27 -0700
Subject: Read hypernet strength from PNG info.
---
modules/ui.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 0a63e357..62a2f4f3 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1812,6 +1812,7 @@ Requested path was: {f}
settings_map = {
'sd_hypernetwork': 'Hypernet',
+ 'sd_hypernetwork_strength': 'Hypernetwork strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
--
cgit v1.2.3
From c0677b33161f04c3ed1a7a78f4c7288fb95787b7 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Thu, 27 Oct 2022 23:31:45 -0700
Subject: Explicitly state when Hypernet is none.
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 93066522..74a0cd64 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -328,7 +328,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernet": ("None" if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Hypernetwork strength": (None if shared.loaded_hypernetwork is None else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
--
cgit v1.2.3
From db5a354c489bfd1c95e0bbf9af12ab8b5d6fe170 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Fri, 28 Oct 2022 01:41:57 -0700
Subject: Always ignore "None.pt" in the hypernet directory.
---
modules/hypernetworks/hypernetwork.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 8113b35b..cd920df5 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -208,13 +208,16 @@ def list_hypernetworks(path):
res = {}
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
name = os.path.splitext(os.path.basename(filename))[0]
- res[name] = filename
+ # Prevent a hypothetical "None.pt" from being listed.
+ if name != "None":
+ res[name] = filename
return res
def load_hypernetwork(filename):
path = shared.hypernetworks.get(filename, None)
- if path is not None:
+ # Prevent any file named "None.pt" from being loaded.
+ if path is not None and filename != "None":
print(f"Loading hypernetwork {filename}")
try:
shared.loaded_hypernetwork = Hypernetwork()
--
cgit v1.2.3
From 9ceef81f77ecce89f0c8f412c4d849210d852e82 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Fri, 28 Oct 2022 20:48:08 +0700
Subject: Fix log off-by-one
---
modules/hypernetworks/hypernetwork.py | 12 +++++++-----
modules/textual_inversion/learn_schedule.py | 2 +-
modules/textual_inversion/textual_inversion.py | 24 ++++++++++++------------
3 files changed, 20 insertions(+), 18 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 8113b35b..a0297997 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -428,7 +428,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
optimizer.step()
- if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+ steps_done = hypernetwork.step + 1
+
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
if len(previous_mean_losses) > 1:
@@ -438,9 +440,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
pbar.set_description(dataset_loss_info)
- if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
- hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
+ hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(last_saved_file)
@@ -449,8 +451,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
"learn_rate": scheduler.learn_rate
})
- if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
- forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 2062726a..3a736065 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -52,7 +52,7 @@ class LearnRateScheduler:
self.finished = False
def apply(self, optimizer, step_number):
- if step_number <= self.end_step:
+ if step_number < self.end_step:
return
try:
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ff002d3e..17dfb223 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -184,9 +184,8 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
- if step % shared.opts.training_write_csv_every != 0:
+ if (step + 1) % shared.opts.training_write_csv_every != 0:
return
-
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
@@ -196,11 +195,11 @@ def write_loss(log_directory, filename, step, epoch_len, values):
csv_writer.writeheader()
epoch = step // epoch_len
- epoch_step = step - epoch * epoch_len
+ epoch_step = step % epoch_len
csv_writer.writerow({
"step": step + 1,
- "epoch": epoch + 1,
+ "epoch": epoch,
"epoch_step": epoch_step + 1,
**values,
})
@@ -282,15 +281,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
loss.backward()
optimizer.step()
+ steps_done = embedding.step + 1
epoch_num = embedding.step // len(ds)
- epoch_step = embedding.step - (epoch_num * len(ds)) + 1
+ epoch_step = embedding.step % len(ds)
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
- if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{embedding.step}'
+ embedding.name = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True
@@ -300,8 +300,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
"learn_rate": scheduler.learn_rate
})
- if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
- forced_filename = f'{embedding_name}-{embedding.step}'
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
@@ -334,7 +334,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
@@ -350,7 +350,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}v {}s'.format(vectorSize, embedding.step)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
--
cgit v1.2.3
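
The common thread of these fixes is switching the periodic save/image/CSV checks from the 0-based step counter to a 1-based steps_done, so "every N steps" fires after the Nth step rather than at step 0 (and the step > 0 guard becomes unnecessary). A small demonstration:

    # Demo of the off-by-one fix: with 0-based step counters, testing the
    # 1-based steps_done fires after the 100th, 200th, ... step instead of
    # skipping the final multiple.
    save_every = 100

    old = [step for step in range(300) if step > 0 and step % save_every == 0]
    new = [step for step in range(300) if (step + 1) % save_every == 0]

    print(old)  # [100, 200]      -> the save for the 300th step is missed
    print(new)  # [99, 199, 299]  -> i.e. after the 100th, 200th, 300th step
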
From 26d08193848568b06105a1ee7b76f338ebf0f0ee Mon Sep 17 00:00:00 2001
From: Chris OBryan <13701027+cobryan05@users.noreply.github.com>
Date: Fri, 28 Oct 2022 13:24:11 -0500
Subject: extras: Add option to run upscaling before face fixing
Face restoration can look much better if run after upscaling, as it
allows the restoration to fix upscaling artifacts. This patch adds
an option to choose which order to run upscaling/face fixing in.
---
modules/extras.py | 145 +++++++++++++++++++++++++++++++++++-------------------
modules/ui.py | 4 ++
2 files changed, 99 insertions(+), 50 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index 22c5a1c1..79047f3a 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -7,6 +7,10 @@ from PIL import Image
import torch
import tqdm
+from typing import Callable, List, Tuple
+from functools import partial
+from dataclasses import dataclass
+
from modules import processing, shared, images, devices, sd_models
from modules.shared import opts
import modules.gfpgan_model
@@ -20,7 +24,7 @@ import gradio as gr
cached_images = {}
-def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
+def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ):
devices.torch_gc()
imageArr = []
@@ -56,68 +60,109 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
else:
outpath = opts.outdir_samples or opts.outdir_extras_samples
-
- for image, image_name in zip(imageArr, imageNameArr):
- if image is None:
- return outputs, "Please select an input image.", ''
- existing_pnginfo = image.info or {}
- image = image.convert("RGB")
- info = ""
+ # Extra operation definitions
+ def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
+ res = Image.fromarray(restored_img)
+
+ if gfpgan_visibility < 1.0:
+ res = Image.blend(image, res, gfpgan_visibility)
+
+ info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
+ return (res, info)
- if gfpgan_visibility > 0:
- restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
- res = Image.fromarray(restored_img)
+ def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
+ res = Image.fromarray(restored_img)
- if gfpgan_visibility < 1.0:
- res = Image.blend(image, res, gfpgan_visibility)
+ if codeformer_visibility < 1.0:
+ res = Image.blend(image, res, codeformer_visibility)
- info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
- image = res
+ info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
+ return (res, info)
- if codeformer_visibility > 0:
- restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
- res = Image.fromarray(restored_img)
- if codeformer_visibility < 1.0:
- res = Image.blend(image, res, codeformer_visibility)
+ def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
+ small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
+ pixels = tuple(np.array(small).flatten().tolist())
+ key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
+ resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
- info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
- image = res
+ c = cached_images.get(key)
+ if c is None:
+ upscaler = shared.sd_upscalers[scaler_index]
+ c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+ if mode == 1 and crop:
+ cropped = Image.new("RGB", (resize_w, resize_h))
+ cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
+ c = cropped
+ cached_images[key] = c
+ return c
+
+ def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
+ nonlocal upscaling_resize
if resize_mode == 1:
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
crop_info = " (crop)" if upscaling_crop else ""
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
+ return (image, info)
+
+ @dataclass
+ class UpscaleParams:
+ upscaler_idx: int
+ blend_alpha: float
+
+ def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ blended_result: Image.Image = None
+ for upscaler in params:
+ res = upscale(image, upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
+ info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
+ if blended_result is None:
+ blended_result = res
+ else:
+ blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
+ return (blended_result, info)
+
+ # Build a list of operations to run
+ facefix_ops: List[Callable] = []
+ if gfpgan_visibility > 0:
+ facefix_ops.append(run_gfpgan)
+ if codeformer_visibility > 0:
+ facefix_ops.append(run_codeformer)
+
+ upscale_ops: List[Callable] = []
+ if resize_mode == 1:
+ upscale_ops.append(run_prepare_crop)
+
+ if upscaling_resize != 0:
+ step_params: List[UpscaleParams] = []
+ step_params.append( UpscaleParams( upscaler_idx=extras_upscaler_1, blend_alpha=1.0 ))
+ if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
+ step_params.append( UpscaleParams( upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility ) )
+
+ upscale_ops.append( partial(run_upscalers_blend, step_params) )
+
+
+ extras_ops: List[Callable] = []
+ if upscale_first:
+ extras_ops = upscale_ops + facefix_ops
+ else:
+ extras_ops = facefix_ops + upscale_ops
+
+
+ for image, image_name in zip(imageArr, imageNameArr):
+ if image is None:
+ return outputs, "Please select an input image.", ''
+ existing_pnginfo = image.info or {}
- if upscaling_resize != 1.0:
- def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
- small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
- pixels = tuple(np.array(small).flatten().tolist())
- key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
- resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
-
- c = cached_images.get(key)
- if c is None:
- upscaler = shared.sd_upscalers[scaler_index]
- c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
- if mode == 1 and crop:
- cropped = Image.new("RGB", (resize_w, resize_h))
- cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
- c = cropped
- cached_images[key] = c
-
- return c
-
- info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
- res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
-
- if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
- info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
- res = Image.blend(res, res2, extras_upscaler_2_visibility)
-
- image = res
+ image = image.convert("RGB")
+ info = ""
+ # Run each operation on each image
+ for op in extras_ops:
+ image, info = op(image, info)
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
diff --git a/modules/ui.py b/modules/ui.py
index 0a63e357..16b6ac49 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1119,6 +1119,9 @@ def create_ui(wrap_gradio_gpu_call):
codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
+ with gr.Group():
+ upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
+
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
with gr.Column(variant='panel'):
@@ -1152,6 +1155,7 @@ def create_ui(wrap_gradio_gpu_call):
extras_upscaler_1,
extras_upscaler_2,
extras_upscaler_2_visibility,
+ upscale_before_face_fix,
],
outputs=[
result_images,
--
cgit v1.2.3
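
The refactor expresses the extras pass as a list of callables, each mapping (image, info) to (image, info), built conditionally and ordered by upscale_first, then folded over every input image. A stripped-down sketch of that composition with stand-in operations:

    # Stripped-down sketch of the operation pipeline above: each op maps
    # (image, info) -> (image, info); the list order decides upscale-vs-
    # facefix order. Stand-in strings keep the example self-contained.
    from functools import partial
    from typing import Callable, List, Tuple

    def facefix(image: str, info: str) -> Tuple[str, str]:
        return image + "+faces", info + "facefix\n"

    def upscale(factor: float, image: str, info: str) -> Tuple[str, str]:
        return image + f"+x{factor}", info + f"upscale x{factor}\n"

    upscale_first = True
    upscale_ops: List[Callable] = [partial(upscale, 2.0)]
    facefix_ops: List[Callable] = [facefix]
    extras_ops = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)

    image, info = "img", ""
    for op in extras_ops:
        image, info = op(image, info)
    print(image)   # img+x2.0+faces -- upscaling ran before face fixing
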
From bde4731f1d3ddf30c46f86c9f6e71e6c0644089d Mon Sep 17 00:00:00 2001
From: Chris OBryan <13701027+cobryan05@users.noreply.github.com>
Date: Fri, 28 Oct 2022 14:30:04 -0500
Subject: extras: Rework image cache
A bit of a refactor of the image cache to make it easier to extend.
It also takes the entire image into account instead of just a cropped portion.
---
modules/extras.py | 52 ++++++++++++++++++++++++++++++++--------------------
1 file changed, 32 insertions(+), 20 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index 79047f3a..cffe0381 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -7,7 +7,7 @@ from PIL import Image
import torch
import tqdm
-from typing import Callable, List, Tuple
+from typing import Callable, Dict, List, Tuple
from functools import partial
from dataclasses import dataclass
@@ -21,7 +21,18 @@ import piexif.helper
import gradio as gr
-cached_images = {}
+@dataclass(frozen=True)
+class CacheKey:
+ image_hash: int
+ info_hash: int
+ args_hash: int
+
+@dataclass
+class CacheEntry:
+ image: Image.Image
+ info: str
+
+cached_images: Dict[CacheKey, CacheEntry] = {}
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ):
@@ -84,22 +95,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
- small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
- pixels = tuple(np.array(small).flatten().tolist())
- key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
- resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
-
- c = cached_images.get(key)
- if c is None:
- upscaler = shared.sd_upscalers[scaler_index]
- c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
- if mode == 1 and crop:
- cropped = Image.new("RGB", (resize_w, resize_h))
- cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
- c = cropped
- cached_images[key] = c
- return c
-
+ upscaler = shared.sd_upscalers[scaler_index]
+ res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+ if mode == 1 and crop:
+ cropped = Image.new("RGB", (resize_w, resize_h))
+ cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
+ res = cropped
+ return res
def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
# Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
@@ -118,8 +120,18 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None
for upscaler in params:
- res = upscale(image, upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
- info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
+ upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
+ cache_key = CacheKey( image_hash = hash(np.array(image.getdata()).tobytes()),
+ info_hash = hash(info),
+ args_hash = hash(upscale_args) )
+ cached_entry = cached_images.get(cache_key)
+ if cached_entry is None:
+ res = upscale(image, *upscale_args)
+ info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
+ cached_images[cache_key] = CacheEntry(image=res, info=info)
+ else:
+ res, info = cached_entry.image, cached_entry.info
+
if blended_result is None:
blended_result = res
else:
--
cgit v1.2.3
From 1f1b327959b546b5e6f995905a1699c5fe4a0c35 Mon Sep 17 00:00:00 2001
From: Chris OBryan <13701027+cobryan05@users.noreply.github.com>
Date: Fri, 28 Oct 2022 16:11:16 -0500
Subject: extras: Make image cache LRU
This changes the extras image cache into a least-recently-used (LRU)
cache, which allows experimenting with different upscalers without
losing cached results.
The maximum cache size is increased to 5, and the cache is cleared when
the source image is updated.
---
modules/extras.py | 67 +++++++++++++++++++++++++++++++------------------------
modules/ui.py | 5 +++++
2 files changed, 43 insertions(+), 29 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index cffe0381..72cc6d1d 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
import math
import os
@@ -7,7 +8,7 @@ from PIL import Image
import torch
import tqdm
-from typing import Callable, Dict, List, Tuple
+from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass
@@ -21,18 +22,34 @@ import piexif.helper
import gradio as gr
-@dataclass(frozen=True)
-class CacheKey:
- image_hash: int
- info_hash: int
- args_hash: int
+class LruCache(OrderedDict):
+ @dataclass(frozen=True)
+ class Key:
+ image_hash: int
+ info_hash: int
+ args_hash: int
-@dataclass
-class CacheEntry:
- image: Image.Image
- info: str
+ @dataclass
+ class Value:
+ image: Image.Image
+ info: str
+
+ def __init__(self, max_size:int = 5, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._max_size = max_size
+
+ def get(self, key: LruCache.Key) -> LruCache.Value:
+ ret = super().get(key)
+ if ret is not None:
+ self.move_to_end(key) # Move to end of eviction list
+ return ret
+
+ def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
+ self[key] = value
+ while len(self) > self._max_size:
+ self.popitem(last=False)
-cached_images: Dict[CacheKey, CacheEntry] = {}
+cached_images: LruCache = LruCache(max_size = 5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ):
@@ -121,14 +138,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
blended_result: Image.Image = None
for upscaler in params:
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
- cache_key = CacheKey( image_hash = hash(np.array(image.getdata()).tobytes()),
+ cache_key = LruCache.Key( image_hash = hash(np.array(image.getdata()).tobytes()),
info_hash = hash(info),
- args_hash = hash(upscale_args) )
+ args_hash = hash(upscale_args + (upscaler.blend_alpha,)) )
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
- cached_images[cache_key] = CacheEntry(image=res, info=info)
+ cached_images.put(cache_key, LruCache.Value(image=res, info=info))
else:
res, info = cached_entry.image, cached_entry.info
@@ -140,14 +157,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
# Build a list of operations to run
facefix_ops: List[Callable] = []
- if gfpgan_visibility > 0:
- facefix_ops.append(run_gfpgan)
- if codeformer_visibility > 0:
- facefix_ops.append(run_codeformer)
+ facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
+ facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
upscale_ops: List[Callable] = []
- if resize_mode == 1:
- upscale_ops.append(run_prepare_crop)
+ upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
if upscaling_resize != 0:
step_params: List[UpscaleParams] = []
@@ -157,12 +171,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
upscale_ops.append( partial(run_upscalers_blend, step_params) )
-
- extras_ops: List[Callable] = []
- if upscale_first:
- extras_ops = upscale_ops + facefix_ops
- else:
- extras_ops = facefix_ops + upscale_ops
+ extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
for image, image_name in zip(imageArr, imageNameArr):
@@ -176,9 +185,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for op in extras_ops:
image, info = op(image, info)
- while len(cached_images) > 2:
- del cached_images[next(iter(cached_images.keys()))]
-
if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
@@ -198,6 +204,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return outputs, plaintext_to_html(info), ''
+def clear_cache():
+ cached_images.clear()
+
def run_pnginfo(image):
if image is None:
diff --git a/modules/ui.py b/modules/ui.py
index 16b6ac49..b7c36c55 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1178,6 +1178,11 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[init_img_with_mask],
)
+ extras_image.change(
+ fn=modules.extras.clear_cache,
+ inputs=[], outputs=[]
+ )
+
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
--
cgit v1.2.3
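Taken together, the hunks above replace the ad-hoc two-entry dict with a small LRU cache built on collections.OrderedDict: get() marks an entry as most recently used, put() evicts from the front once the size cap is exceeded, and clear_cache() is wired to extras_image.change so selecting a new source image flushes stale entries. A consolidated, runnable sketch of the same idea (the NamedTuple stand-ins replace the real LruCache.Key/Value dataclasses):

import collections
from typing import NamedTuple, Optional

class Key(NamedTuple):      # stand-in for LruCache.Key
    image_hash: int
    info_hash: int
    args_hash: int

class Value(NamedTuple):    # stand-in for LruCache.Value
    image: object
    info: str

class LruCache(collections.OrderedDict):
    def __init__(self, max_size: int = 5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._max_size = max_size

    def get(self, key: Key) -> Optional[Value]:
        ret = super().get(key)
        if ret is not None:
            self.move_to_end(key)      # most recently used moves to the back
        return ret

    def put(self, key: Key, value: Value) -> None:
        self[key] = value
        while len(self) > self._max_size:
            self.popitem(last=False)   # evict from the front (least recently used)

cache = LruCache(max_size=2)
cache.put(Key(1, 1, 1), Value(image=None, info="a"))
cache.put(Key(2, 2, 2), Value(image=None, info="b"))
cache.get(Key(1, 1, 1))               # touch: Key(1, 1, 1) is now most recent
cache.put(Key(3, 3, 3), Value(image=None, info="c"))
assert Key(2, 2, 2) not in cache      # the untouched entry was evicted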
From 5732c0282d529ef2e0591c76e16959e97240dad8 Mon Sep 17 00:00:00 2001
From: Chris OBryan <13701027+cobryan05@users.noreply.github.com>
Date: Fri, 28 Oct 2022 16:36:25 -0500
Subject: extras-tweaks: autoformat changed lines
---
modules/extras.py | 30 +++++++++++++++---------------
1 file changed, 15 insertions(+), 15 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index 72cc6d1d..50026a25 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -34,14 +34,14 @@ class LruCache(OrderedDict):
image: Image.Image
info: str
- def __init__(self, max_size:int = 5, *args, **kwargs):
+ def __init__(self, max_size: int = 5, *args, **kwargs):
super().__init__(*args, **kwargs)
self._max_size = max_size
def get(self, key: LruCache.Key) -> LruCache.Value:
ret = super().get(key)
if ret is not None:
- self.move_to_end(key) # Move to end of eviction list
+ self.move_to_end(key) # Move to end of eviction list
return ret
def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
@@ -49,10 +49,11 @@ class LruCache(OrderedDict):
while len(self) > self._max_size:
self.popitem(last=False)
-cached_images: LruCache = LruCache(max_size = 5)
+cached_images: LruCache = LruCache(max_size=5)
-def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ):
+
+def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool):
devices.torch_gc()
imageArr = []
@@ -88,8 +89,8 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
else:
outpath = opts.outdir_samples or opts.outdir_extras_samples
-
# Extra operation definitions
+
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)
@@ -110,7 +111,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
return (res, info)
-
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
upscaler = shared.sd_upscalers[scaler_index]
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
@@ -134,13 +134,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
upscaler_idx: int
blend_alpha: float
- def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None
for upscaler in params:
- upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
- cache_key = LruCache.Key( image_hash = hash(np.array(image.getdata()).tobytes()),
- info_hash = hash(info),
- args_hash = hash(upscale_args + (upscaler.blend_alpha,)) )
+ upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
+ upscaling_resize_w, upscaling_resize_h, upscaling_crop)
+ cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
+ info_hash=hash(info),
+ args_hash=hash(upscale_args + (upscaler.blend_alpha,)))
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
@@ -165,15 +166,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if upscaling_resize != 0:
step_params: List[UpscaleParams] = []
- step_params.append( UpscaleParams( upscaler_idx=extras_upscaler_1, blend_alpha=1.0 ))
+ step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- step_params.append( UpscaleParams( upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility ) )
+ step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
- upscale_ops.append( partial(run_upscalers_blend, step_params) )
+ upscale_ops.append(partial(run_upscalers_blend, step_params))
extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
-
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
--
cgit v1.2.3
From d8b366146748555a18b595af400c8cb222ea0ec9 Mon Sep 17 00:00:00 2001
From: Chris OBryan <13701027+cobryan05@users.noreply.github.com>
Date: Fri, 28 Oct 2022 16:55:02 -0500
Subject: extras: upscaler blending should not be considered in cache key
---
modules/extras.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index 50026a25..681d8d5a 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -141,7 +141,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
info_hash=hash(info),
- args_hash=hash(upscale_args + (upscaler.blend_alpha,)))
+ args_hash=hash(upscale_args))
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
--
cgit v1.2.3
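Excluding blend_alpha from the key is correct because the cache stores the raw upscaler output; the visibility blend happens after the lookup, so requests that differ only in blend weight can reuse the same cached upscale. A small sketch of that post-lookup step (Image.blend is the real PIL call; the surrounding names are illustrative):

from PIL import Image

def blend_step(blended_result, res, blend_alpha):
    # res is either a fresh upscale or a cache hit; alpha is applied
    # here, after the cache, so it must not influence the cache key.
    if blended_result is None:
        return res
    return Image.blend(blended_result, res, blend_alpha)

base = Image.new("RGB", (4, 4), "black")
layer = Image.new("RGB", (4, 4), "white")
result = blend_step(base, layer, 0.25)   # a 25% blend toward white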
From 539c0f51e436beeb0ca2b8b8d52b24f4b59ad56a Mon Sep 17 00:00:00 2001
From: Yaiol <38218161+Yaiol@users.noreply.github.com>
Date: Sat, 29 Oct 2022 01:07:01 +0200
Subject: Update images.py
Filename tags [height] and [width] wrongly reference the process size instead of the resulting image size, so all upscaled files are named incorrectly.
---
modules/images.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index 7870b5b7..a0728553 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -300,8 +300,8 @@ class FilenameGenerator:
'seed': lambda self: self.seed if self.seed is not None else '',
'steps': lambda self: self.p and self.p.steps,
'cfg': lambda self: self.p and self.p.cfg_scale,
- 'width': lambda self: self.p and self.p.width,
- 'height': lambda self: self.p and self.p.height,
+ 'width': lambda self: self.image.width,
+ 'height': lambda self: self.image.height,
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
@@ -315,10 +315,11 @@ class FilenameGenerator:
}
default_time_format = '%Y%m%d%H%M%S'
- def __init__(self, p, seed, prompt):
+ def __init__(self, p, seed, prompt, image):
self.p = p
self.seed = seed
self.prompt = prompt
+ self.image = image
def prompt_no_style(self):
if self.p is None or self.prompt is None:
@@ -449,7 +450,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
txt_fullfn (`str` or None):
If a text file is saved for this image, this will be its full path. Otherwise None.
"""
- namegen = FilenameGenerator(p, seed, prompt)
+ namegen = FilenameGenerator(p, seed, prompt, image)
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
--
cgit v1.2.3
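The fix routes [width] and [height] through the image actually being saved instead of the processing parameters, which no longer match after an upscale. A reduced sketch of the tag-to-lambda lookup (the replacements table follows the patch; the apply loop is simplified from the real pattern parser):

from PIL import Image

class FilenameGenerator:
    replacements = {
        'width':  lambda self: self.image.width,   # resulting image size,
        'height': lambda self: self.image.height,  # not the process size
        'seed':   lambda self: self.seed if self.seed is not None else '',
    }

    def __init__(self, p, seed, prompt, image):
        self.p = p
        self.seed = seed
        self.prompt = prompt
        self.image = image                         # the image being saved

    def apply(self, pattern: str) -> str:
        for tag, fn in self.replacements.items():
            pattern = pattern.replace(f'[{tag}]', str(fn(self)))
        return pattern

# e.g. a 512x512 generation upscaled 2x:
gen = FilenameGenerator(p=None, seed=42, prompt='', image=Image.new('RGB', (1024, 1024)))
print(gen.apply('[seed]-[width]x[height]'))        # -> 42-1024x1024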
From f361e804ebaa5af4a10711ece2522869fb64a4c6 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sat, 29 Oct 2022 08:36:50 +0900
Subject: Re-enable linear
---
modules/hypernetworks/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index c2d4b51c..aad09ffc 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -9,7 +9,7 @@ from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
not_available = ["hardswish", "multiheadattention"]
-keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
--
cgit v1.2.3
From bce5adcd6de1ad608df8d813a92e1167b7a1d3f2 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 29 Oct 2022 07:37:06 +0300
Subject: change default hypernet activation function to linear
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 0a63e357..2541970d 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1238,7 +1238,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
+ new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
--
cgit v1.2.3
From a1e5e0d7669def010ecf31d801d6f0667bcf8061 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 29 Oct 2022 08:11:03 +0300
Subject: skip filenames starting with . for img2img and extras batch modes
---
modules/extras.py | 2 +-
modules/img2img.py | 2 +-
modules/shared.py | 5 +++++
3 files changed, 7 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index 681d8d5a..4d51088b 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -72,7 +72,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if input_dir == '':
return outputs, "Please select an input directory.", ''
- image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
+ image_list = shared.listfiles(input_dir)
for img in image_list:
try:
image = Image.open(img)
diff --git a/modules/img2img.py b/modules/img2img.py
index 9c0cf23e..efda26e1 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -19,7 +19,7 @@ import modules.scripts
def process_batch(p, input_dir, output_dir, args):
processing.fix_seed(p)
- images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+ images = shared.listfiles(input_dir)
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
diff --git a/modules/shared.py b/modules/shared.py
index 7c428d90..7e634423 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -450,3 +450,8 @@ total_tqdm = TotalTQDM()
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()
+
+
+def listfiles(dirname):
+ filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
+ return [file for file in filenames if os.path.isfile(file)]
--
cgit v1.2.3
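shared.listfiles centralizes the directory scan for both batch modes: entries are sorted, hidden files (e.g. .DS_Store or .gitkeep) are skipped, and non-files are filtered out. The helper as added, plus a quick self-check:

import os
import tempfile

def listfiles(dirname):
    # sorted, dotfiles skipped, subdirectories filtered out
    filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
    return [file for file in filenames if os.path.isfile(file)]

with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, "a.png"), "w").close()
    open(os.path.join(d, ".DS_Store"), "w").close()
    os.mkdir(os.path.join(d, "subdir"))
    assert [os.path.basename(f) for f in listfiles(d)] == ["a.png"]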
From 2d220afb24bd9812d5124814f670ec2a1ff5b0fe Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 29 Oct 2022 08:26:12 +0300
Subject: fix open folder button not working
---
modules/ui.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 922a2163..20cc10cf 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -631,9 +631,9 @@ Requested path was: {f}
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
- open_folder = gr.Button(folder_symbol, elem_id=button_id)
+ open_folder_button = gr.Button(folder_symbol, elem_id=button_id)
- open_folder.click(
+ open_folder_button.click(
fn=lambda: open_folder(opts.outdir_samples or outdir),
inputs=[],
outputs=[],
--
cgit v1.2.3
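The underlying bug was name shadowing: binding the gr.Button instance to open_folder rebound the very name the click lambda needed, so the handler resolved to the button object rather than the open_folder function defined just above it. A minimal reproduction of the pattern (a plain object stands in for the Gradio widget):

def open_folder(path):                 # the helper the handler should call
    print(f"opening {path}")

open_folder = object()                 # was: open_folder = gr.Button(...)
handler = lambda: open_folder("outputs")
# Calling handler() now raises TypeError: 'object' object is not callable,
# because the lambda looks the name up at call time and finds the widget.
# The fix gives the widget a distinct name: open_folder_button = gr.Button(...)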
From a33d0a9a65189be35038be5765f2b4d1b19bab0f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 29 Oct 2022 08:28:48 +0300
Subject: remove weird spaces added to ui.py over time
---
modules/ui.py | 51 +++++++++++++++++++++++++--------------------------
1 file changed, 25 insertions(+), 26 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 20cc10cf..46657dd4 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -596,7 +596,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
)
return refresh_button
-def create_output_panel(tabname, outdir):
+def create_output_panel(tabname, outdir):
def open_folder(f):
if not os.path.exists(f):
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
@@ -618,11 +618,11 @@ Requested path was: {f}
sp.Popen(["open", path])
else:
sp.Popen(["xdg-open", path])
-
- with gr.Column(variant='panel'):
- with gr.Group():
+
+ with gr.Column(variant='panel'):
+ with gr.Group():
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
-
+
generation_info = None
with gr.Column():
with gr.Row():
@@ -639,7 +639,7 @@ Requested path was: {f}
outputs=[],
)
- if tabname != "extras":
+ if tabname != "extras":
with gr.Row():
do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
@@ -671,8 +671,7 @@ Requested path was: {f}
html_info = gr.HTML()
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
-
-
+
def create_ui(wrap_gradio_gpu_call):
import modules.img2img
@@ -723,10 +722,10 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
-
- txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
-
+
+ txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -781,7 +780,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=lambda x: gr_show(x),
inputs=[enable_hr],
outputs=[hr_options],
- )
+ )
roll.click(
fn=roll_artist,
@@ -902,7 +901,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
- img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
+ img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -990,7 +989,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[init_img],
outputs=[img2img_prompt],
)
-
+
roll.click(
fn=roll_artist,
@@ -1045,7 +1044,7 @@ def create_ui(wrap_gradio_gpu_call):
(denoising_strength, "Denoising strength"),
*modules.scripts.scripts_img2img.infotext_fields
]
- parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
+ parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
@@ -1077,9 +1076,9 @@ def create_ui(wrap_gradio_gpu_call):
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
-
+
with gr.Group():
- extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
+ extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
@@ -1125,7 +1124,7 @@ def create_ui(wrap_gradio_gpu_call):
html_info,
]
)
- parameters_copypaste.add_paste_fields("extras", extras_image, None)
+ parameters_copypaste.add_paste_fields("extras", extras_image, None)
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
@@ -1139,14 +1138,14 @@ def create_ui(wrap_gradio_gpu_call):
html2 = gr.HTML()
with gr.Row():
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
- parameters_copypaste.bind_buttons(buttons, image, generation_info)
-
+ parameters_copypaste.bind_buttons(buttons, image, generation_info)
+
image.change(
fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image],
outputs=[html, generation_info, html2],
)
-
+
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -1569,7 +1568,7 @@ def create_ui(wrap_gradio_gpu_call):
column.__exit__()
-
+
interfaces = [
(txt2img_interface, "txt2img", "txt2img"),
@@ -1582,7 +1581,7 @@ def create_ui(wrap_gradio_gpu_call):
interfaces += script_callbacks.ui_tabs_callback()
- interfaces += [(settings_interface, "Settings", "settings")]
+ interfaces += [(settings_interface, "Settings", "settings")]
css = ""
@@ -1661,7 +1660,7 @@ def create_ui(wrap_gradio_gpu_call):
component_dict['sd_model_checkpoint'],
]
)
-
+
settings_map = {
'sd_hypernetwork': 'Hypernet',
@@ -1669,7 +1668,7 @@ def create_ui(wrap_gradio_gpu_call):
'sd_model_checkpoint': 'Model hash',
}
- parameters_copypaste.run_bind()
+ parameters_copypaste.run_bind()
ui_config_file = cmd_opts.ui_config_file
ui_settings = {}
@@ -1749,7 +1748,7 @@ def load_javascript(raw_response):
javascript = f'<script>{jsfile.read()}</script>'
scripts_list = modules.scripts.list_scripts("javascript", ".js")
-
+
for basedir, filename, path in scripts_list:
with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n<script>{jsfile.read()}</script>"
--
cgit v1.2.3
From 3c207ca68483b3406faf519bde2743b578dac222 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 29 Oct 2022 08:42:34 +0300
Subject: add needed imports for new code in copypaste.py
---
modules/generation_parameters_copypaste.py | 9 +++++++++
modules/ui.py | 7 -------
2 files changed, 9 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 2b80737a..224a17ea 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,3 +1,5 @@
+import base64
+import io
import os
import re
import gradio as gr
@@ -14,6 +16,7 @@ type_of_gr_update = type(gr.update())
paste_fields = {}
bind_list = []
+
def quote(text):
if ',' not in str(text):
return text
@@ -23,6 +26,7 @@ def quote(text):
text = text.replace('"', '\\"')
return f'"{text}"'
+
def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]:
filename = filedata["name"]
@@ -45,19 +49,23 @@ def image_from_url_text(filedata):
image = Image.open(io.BytesIO(filedata))
return image
+
def add_paste_fields(tabname, init_img, fields):
paste_fields[tabname] = {"init_img":init_img, "fields": fields}
+
def create_buttons(tabs_list):
buttons = {}
for tab in tabs_list:
buttons[tab] = gr.Button(f"Send to {tab}")
return buttons
+
#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info])
+
def run_bind():
for buttons, send_image, send_generate_info in bind_list:
for tab in buttons:
@@ -98,6 +106,7 @@ def run_bind():
outputs=None,
)
+
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
diff --git a/modules/ui.py b/modules/ui.py
index 46657dd4..280910d0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1,6 +1,4 @@
-import base64
import html
-import io
import json
import math
import mimetypes
@@ -18,13 +16,8 @@ import gradio as gr
import gradio.routes
import gradio.utils
import numpy as np
-import piexif
-import torch
from PIL import Image, PngImagePlugin
-import gradio as gr
-import gradio.utils
-import gradio.routes
from modules import sd_hijack, sd_models, localization, script_callbacks
from modules.paths import script_path
--
cgit v1.2.3
From 2922d8144f677ea3c37189a2b2e3c3d3ff6d3916 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 29 Oct 2022 09:01:04 +0300
Subject: make existing image browser extension not break
---
modules/generation_parameters_copypaste.py | 19 ++++++++++++++-----
1 file changed, 14 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 224a17ea..d590e9ee 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -51,7 +51,14 @@ def image_from_url_text(filedata):
def add_paste_fields(tabname, init_img, fields):
- paste_fields[tabname] = {"init_img":init_img, "fields": fields}
+ paste_fields[tabname] = {"init_img": init_img, "fields": fields}
+
+ # backwards compatibility for existing extensions
+ import modules.ui
+ if tabname == 'txt2img':
+ modules.ui.txt2img_paste_fields = fields
+ elif tabname == 'img2img':
+ modules.ui.img2img_paste_fields = fields
def create_buttons(tabs_list):
@@ -61,7 +68,7 @@ def create_buttons(tabs_list):
return buttons
-#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
+#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info])
@@ -84,12 +91,12 @@ def run_bind():
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
-
+
if send_generate_info and paste_fields[tab]["fields"] is not None:
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2']
if shared.opts.send_seed:
paste_field_names += ["Seed"]
- if send_generate_info in paste_fields:
+ if send_generate_info in paste_fields:
button.click(
fn=lambda *x:x,
inputs=[field for field,name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
@@ -154,7 +161,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res
-def connect_paste(button, paste_fields, input_comp):
+
+def connect_paste(button, paste_fields, input_comp, jsfunc=None):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(script_path, "params.txt")
@@ -192,6 +200,7 @@ def connect_paste(button, paste_fields, input_comp):
button.click(
fn=paste_func,
+ _js=jsfunc,
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
)
--
cgit v1.2.3
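Older extensions import txt2img_paste_fields and img2img_paste_fields directly from modules.ui, so the shim mirrors each registration back onto the legacy module attributes. The idea in isolation (legacy_ui is a hypothetical stand-in for modules.ui):

import sys
import types

legacy = types.ModuleType("legacy_ui")   # stand-in for modules.ui
sys.modules["legacy_ui"] = legacy

paste_fields = {}

def add_paste_fields(tabname, init_img, fields):
    paste_fields[tabname] = {"init_img": init_img, "fields": fields}
    # backwards compatibility: keep the old module-level name alive
    # so existing extensions that read it keep working
    if tabname == "txt2img":
        legacy.txt2img_paste_fields = fields

add_paste_fields("txt2img", None, [("prompt_box", "Prompt")])
assert sys.modules["legacy_ui"].txt2img_paste_fields == [("prompt_box", "Prompt")]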
From 28e6d4a54ea1fa1e34ad1ea0742ab2003ed7fa7f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 29 Oct 2022 09:13:36 +0300
Subject: add element ids for save buttons for #3798
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 66b743f5..3c34eca0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -620,7 +620,7 @@ Requested path was: {f}
with gr.Column():
with gr.Row():
if tabname != "extras":
- save = gr.Button('Save')
+ save = gr.Button('Save', elem_id=f'save_{tabname}')
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
--
cgit v1.2.3
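elem_id gives each tab's Save button a stable DOM id (save_txt2img, save_img2img, ...), which lets user CSS and javascript target one tab's button specifically. A minimal sketch, assuming a Gradio Blocks context:

import gradio as gr

with gr.Blocks() as demo:
    for tabname in ["txt2img", "img2img"]:
        # rendered with id="save_txt2img" / id="save_img2img" in the DOM
        save = gr.Button('Save', elem_id=f'save_{tabname}')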
From beb6fc29798d82f1b08a34cf5dd79e4ab29d4cd0 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 29 Oct 2022 09:57:22 +0300
Subject: move send seed option to UI section and make it false by default
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 5d1ceb85..fb84afd8 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -280,7 +280,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
- "send_seed": OptionInfo(False, "Send seed when sending prompt or image to other interface"),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
@@ -306,6 +305,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
+ "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
--
cgit v1.2.3
From 2c4d20388425a5e40b93eef3722e42e8d375fbb4 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Sat, 29 Oct 2022 00:36:51 -0700
Subject: Revert "Explicitly state when Hypernet is none."
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 377c0978..04fdda7c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -395,7 +395,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": ("None" if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Hypernetwork strength": (None if shared.loaded_hypernetwork is None else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
--
cgit v1.2.3
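The revert matters because infotext assembly drops None-valued parameters entirely, while the string "None" would always be printed; with the object None restored, an unset hypernet simply does not appear. A hedged sketch of that filtering step (key names from the patch; the real join also quotes values):

generation_params = {
    "Model hash": "abc123",
    "Hypernet": None,              # unset: omitted from the output
    "Hypernetwork strength": None,
    "Batch size": 4,
}

infotext = ", ".join(f"{k}: {v}" for k, v in generation_params.items() if v is not None)
print(infotext)                    # Model hash: abc123, Batch size: 4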
From 35c45df28b303a05d56a13cb56d4046f08cf8c25 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 29 Oct 2022 10:56:19 +0300
Subject: fix broken ↙ button, fix field paste ignoring most of the useful fields for #3768
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
modules/generation_parameters_copypaste.py | 36 ++++++++++++++++++-------
modules/ui.py | 43 +++++++++++-------------------
2 files changed, 41 insertions(+), 38 deletions(-)
(limited to 'modules')
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index d590e9ee..bbaad42e 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -6,7 +6,7 @@ import gradio as gr
from modules.shared import script_path
from modules import shared
import tempfile
-from PIL import Image, PngImagePlugin
+from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
@@ -61,6 +61,24 @@ def add_paste_fields(tabname, init_img, fields):
modules.ui.img2img_paste_fields = fields
+def integrate_settings_paste_fields(component_dict):
+ from modules import ui
+
+ settings_map = {
+ 'sd_hypernetwork': 'Hypernet',
+ 'CLIP_stop_at_last_layers': 'Clip skip',
+ 'sd_model_checkpoint': 'Model hash',
+ }
+ settings_paste_fields = [
+ (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
+ for k, v in settings_map.items()
+ ]
+
+ for tabname, info in paste_fields.items():
+ if info["fields"] is not None:
+ info["fields"] += settings_paste_fields
+
+
def create_buttons(tabs_list):
buttons = {}
for tab in tabs_list:
@@ -87,24 +105,22 @@ def run_bind():
)
else:
button.click(
- fn=lambda x:x,
+ fn=lambda x: x,
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
if send_generate_info and paste_fields[tab]["fields"] is not None:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2']
- if shared.opts.send_seed:
- paste_field_names += ["Seed"]
if send_generate_info in paste_fields:
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + (["Seed"] if shared.opts.send_seed else [])
+
button.click(
- fn=lambda *x:x,
- inputs=[field for field,name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
- outputs=[field for field,name in paste_fields[tab]["fields"] if name in paste_field_names],
+ fn=lambda *x: x,
+ inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
+ outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
)
-
else:
- connect_paste(button, [(field, name) for field, name in paste_fields[tab]["fields"] if name in paste_field_names], send_generate_info)
+ connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
button.click(
fn=None,
diff --git a/modules/ui.py b/modules/ui.py
index 3c34eca0..5055ca64 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -589,6 +589,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
)
return refresh_button
+
def create_output_panel(tabname, outdir):
def open_folder(f):
if not os.path.exists(f):
@@ -716,6 +717,7 @@ def create_ui(wrap_gradio_gpu_call):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+ parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -784,7 +786,7 @@ def create_ui(wrap_gradio_gpu_call):
]
)
- parameters_copypaste.add_paste_fields("txt2img", None, [
+ txt2img_paste_fields = [
(txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"),
(steps, "Steps"),
@@ -805,7 +807,8 @@ def create_ui(wrap_gradio_gpu_call):
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
*modules.scripts.scripts_txt2img.infotext_fields
- ])
+ ]
+ parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
txt2img_preview_params = [
txt2img_prompt,
@@ -893,6 +896,7 @@ def create_ui(wrap_gradio_gpu_call):
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
+ parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -1038,7 +1042,6 @@ def create_ui(wrap_gradio_gpu_call):
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
-
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -1050,12 +1053,8 @@ def create_ui(wrap_gradio_gpu_call):
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
with gr.TabItem('Batch from Directory'):
- extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs,
- placeholder="A directory on the same machine where the server is running."
- )
- extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs,
- placeholder="Leave blank to save images to the default path."
- )
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.")
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
show_extras_results = gr.Checkbox(label='Show result images', value=True)
with gr.Tabs(elem_id="extras_resize_mode"):
@@ -1087,7 +1086,6 @@ def create_ui(wrap_gradio_gpu_call):
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
submit.click(
@@ -1121,7 +1119,6 @@ def create_ui(wrap_gradio_gpu_call):
)
parameters_copypaste.add_paste_fields("extras", extras_image, None)
-
extras_image.change(
fn=modules.extras.clear_cache,
inputs=[], outputs=[]
@@ -1587,9 +1584,6 @@ def create_ui(wrap_gradio_gpu_call):
if column is not None:
column.__exit__()
-
-
-
interfaces = [
(txt2img_interface, "txt2img", "txt2img"),
(img2img_interface, "img2img", "img2img"),
@@ -1599,10 +1593,6 @@ def create_ui(wrap_gradio_gpu_call):
(train_interface, "Train", "ti"),
]
- interfaces += script_callbacks.ui_tabs_callback()
-
- interfaces += [(settings_interface, "Settings", "settings")]
-
css = ""
for cssfile in modules.scripts.list_files_with_name("style.css"):
@@ -1619,6 +1609,9 @@ def create_ui(wrap_gradio_gpu_call):
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
+ interfaces += script_callbacks.ui_tabs_callback()
+ interfaces += [(settings_interface, "Settings", "settings")]
+
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list:
@@ -1627,6 +1620,9 @@ def create_ui(wrap_gradio_gpu_call):
settings_interface.gradio_ref = demo
+ parameters_copypaste.integrate_settings_paste_fields(component_dict)
+ parameters_copypaste.run_bind()
+
with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
@@ -1681,15 +1677,6 @@ def create_ui(wrap_gradio_gpu_call):
]
)
-
- settings_map = {
- 'sd_hypernetwork': 'Hypernet',
- 'CLIP_stop_at_last_layers': 'Clip skip',
- 'sd_model_checkpoint': 'Model hash',
- }
-
- parameters_copypaste.run_bind()
-
ui_config_file = cmd_opts.ui_config_file
ui_settings = {}
settings_count = len(ui_settings)
@@ -1708,7 +1695,7 @@ def create_ui(wrap_gradio_gpu_call):
def apply_field(obj, field, condition=None, init_field=None):
key = path + "/" + field
- if getattr(obj,'custom_script_source',None) is not None:
+ if getattr(obj, 'custom_script_source', None) is not None:
key = 'customscript/' + obj.custom_script_source + '/' + key
if getattr(obj, 'do_not_save_to_config', False):
--
cgit v1.2.3
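A detail worth noting in integrate_settings_paste_fields is the k=k, v=v default arguments: without them, every lambda created in the loop would late-bind to the loop variables and see only their final values. A standalone demonstration (apply_setting is a stand-in for ui.apply_setting):

settings_map = {
    'sd_hypernetwork': 'Hypernet',
    'CLIP_stop_at_last_layers': 'Clip skip',
    'sd_model_checkpoint': 'Model hash',
}

def apply_setting(key, value):           # stand-in for ui.apply_setting
    return f"{key} <- {value}"

# default arguments freeze k and v at definition time
callbacks = [lambda d, k=k, v=v: apply_setting(k, d.get(v, None))
             for k, v in settings_map.items()]

params = {"Hypernet": "anime_2", "Clip skip": 2}
for cb in callbacks:
    print(cb(params))
# sd_hypernetwork <- anime_2
# CLIP_stop_at_last_layers <- 2
# sd_model_checkpoint <- None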
From a5f3adbdd7d9b8245f7782216ac48913660e6bb5 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 15:37:24 +0700
Subject: Allow trailing comma in learning rate
---
modules/textual_inversion/learn_schedule.py | 33 +++++++++++++++++------------
1 file changed, 20 insertions(+), 13 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 3a736065..76e611b6 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -11,23 +11,30 @@ class LearnScheduleIterator:
self.rates = []
self.it = 0
self.maxit = 0
- for i, pair in enumerate(pairs):
- tmp = pair.split(':')
- if len(tmp) == 2:
- step = int(tmp[1])
- if step > cur_step:
- self.rates.append((float(tmp[0]), min(step, max_steps)))
- self.maxit += 1
- if step > max_steps:
+ try:
+ for i, pair in enumerate(pairs):
+ if not pair.strip():
+ continue
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
return
- elif step == -1:
+ else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
- else:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
+ assert self.rates
+ except (ValueError, AssertionError):
+ raise Exception("Invalid learning rate schedule")
+
def __iter__(self):
return self
--
cgit v1.2.3
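The schedule string is split on commas, so a trailing comma produces an empty pair; skipping blank pairs and asserting that at least one rate was parsed makes "0.001:100, 1e-5:1000," valid while still rejecting garbage. A condensed sketch of the parsing loop (same rate:step grammar as the patch, omitting the -1 sentinel and cur_step handling):

def parse_schedule(learn_rate: str, max_steps: int):
    rates = []
    try:
        for pair in learn_rate.split(','):
            if not pair.strip():            # tolerate a trailing comma
                continue
            tmp = pair.split(':')
            if len(tmp) == 2:
                rates.append((float(tmp[0]), min(int(tmp[1]), max_steps)))
            else:                           # bare rate: holds until max_steps
                rates.append((float(tmp[0]), max_steps))
        assert rates
    except (ValueError, AssertionError):
        raise Exception("Invalid learning rate schedule")
    return rates

print(parse_schedule("0.001:100, 0.00001:1000,", 10000))
# [(0.001, 100), (1e-05, 1000)]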
From ef4c94e1cfe66299227aa95a28c2380d21cb1600 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 15:42:51 +0700
Subject: Improve lr schedule error message
---
modules/textual_inversion/learn_schedule.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 76e611b6..dd0c0ad1 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -4,7 +4,7 @@ import tqdm
class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
"""
- specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
"""
pairs = learn_rate.split(',')
@@ -33,7 +33,7 @@ class LearnScheduleIterator:
return
assert self.rates
except (ValueError, AssertionError):
- raise Exception("Invalid learning rate schedule")
+ raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
def __iter__(self):
--
cgit v1.2.3
From ab27c111d06ec920791c73eea25ad9a61671852e Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 18:09:17 +0700
Subject: Add input validations before loading dataset for training
---
modules/hypernetworks/hypernetwork.py | 38 +++++++++++---------
modules/textual_inversion/textual_inversion.py | 48 +++++++++++++++++++-------
2 files changed, 58 insertions(+), 28 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 2e84583b..38f35c58 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -332,7 +332,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
- assert hypernetwork_name, 'hypernetwork not selected'
+ save_hypernetwork_every = save_hypernetwork_every or 0
+ create_image_every = create_image_every or 0
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
@@ -358,39 +360,43 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
else:
images_dir = None
+ hypernetwork = shared.loaded_hypernetwork
+
+ ititial_step = hypernetwork.step or 0
+ if ititial_step > steps:
+ shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+ return hypernetwork, filename
+
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
- hypernetwork = shared.loaded_hypernetwork
- weights = hypernetwork.weights()
- for weight in weights:
- weight.requires_grad = True
-
size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
previous_mean_losses = [0]
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
-
- last_saved_file = ""
- last_saved_image = ""
- forced_filename = ""
-
- ititial_step = hypernetwork.step or 0
- if ititial_step > steps:
- return hypernetwork, filename
-
- scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+ weights = hypernetwork.weights()
+ for weight in weights:
+ weight.requires_grad = True
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
steps_without_grad = 0
+ last_saved_file = ""
+ last_saved_image = ""
+ forced_filename = ""
+
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 17dfb223..44f06443 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
+def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+ assert model_name, f"{name} not selected"
+ assert learn_rate, "Learning rate is empty or 0"
+ assert isinstance(batch_size, int), "Batch size must be integer"
+ assert batch_size > 0, "Batch size must be positive"
+ assert data_root, "Dataset directory is empty"
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
+ assert template_file, "Prompt template file is empty"
+ assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+ assert steps, "Max steps is empty or 0"
+ assert isinstance(steps, int), "Max steps must be integer"
+ assert steps > 0 , "Max steps must be positive"
+ assert isinstance(save_model_every, int), "Save {name} must be integer"
+ assert save_model_every >= 0 , "Save {name} must be positive or 0"
+ assert isinstance(create_image_every, int), "Create image must be integer"
+ assert create_image_every >= 0 , "Create image must be positive or 0"
+ if save_model_every or create_image_every:
+ assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- assert embedding_name, 'embedding not selected'
+ save_embedding_every = save_embedding_every or 0
+ create_image_every = create_image_every or 0
+ validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
@@ -232,17 +253,27 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
os.makedirs(images_embeds_dir, exist_ok=True)
else:
images_embeds_dir = None
-
+
cond_model = shared.sd_model.cond_stage_model
+ hijack = sd_hijack.model_hijack
+
+ embedding = hijack.embedding_db.word_embeddings[embedding_name]
+
+ ititial_step = embedding.step or 0
+ if ititial_step > steps:
+ shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+ return embedding, filename
+
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
- hijack = sd_hijack.model_hijack
-
- embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
losses = torch.zeros((32,))
@@ -251,13 +282,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
forced_filename = ""
embedding_yet_to_be_embedded = False
- ititial_step = embedding.step or 0
- if ititial_step > steps:
- return embedding, filename
-
- scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
-
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar:
embedding.step = i + ititial_step
--
cgit v1.2.3
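The reordering is the substantive part of this change: input validation and the already-trained early return now run before the dataset is constructed, so a bad argument fails immediately instead of after minutes of image preprocessing. The shape of the pattern, reduced to a sketch (load_dataset is a hypothetical stand-in for PersonalizedBase):

import os

def validate_train_inputs(name, learn_rate, batch_size, data_root, steps):
    assert name, "model not selected"
    assert learn_rate, "Learning rate is empty or 0"
    assert isinstance(batch_size, int) and batch_size > 0, "Batch size must be a positive integer"
    assert data_root and os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert isinstance(steps, int) and steps > 0, "Max steps must be a positive integer"

def load_dataset(data_root):           # hypothetical stand-in for PersonalizedBase
    return []

def train(name, learn_rate, batch_size, data_root, steps, initial_step=0):
    # cheap checks first: fail in milliseconds, not after preprocessing
    validate_train_inputs(name, learn_rate, batch_size, data_root, steps)
    if initial_step > steps:
        return None                    # already trained beyond max steps
    ds = load_dataset(data_root)       # the expensive part comes last
    return ds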
From 3ce2bfdf95bd5f26d0f6e250e67338ada91980d1 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 19:43:21 +0700
Subject: Add cleanup after training
---
modules/hypernetworks/hypernetwork.py | 201 +++++++++++++------------
modules/textual_inversion/textual_inversion.py | 185 ++++++++++++-----------
2 files changed, 200 insertions(+), 186 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 38f35c58..170d5ea4 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -398,110 +398,112 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
forced_filename = ""
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, entries in pbar:
- hypernetwork.step = i + ititial_step
- if len(loss_dict) > 0:
- previous_mean_losses = [i[-1] for i in loss_dict.values()]
- previous_mean_loss = mean(previous_mean_losses)
-
- scheduler.apply(optimizer, hypernetwork.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = stack_conds([entry.cond for entry in entries]).to(devices.device)
- # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
- del c
-
- losses[hypernetwork.step % losses.shape[0]] = loss.item()
- for entry in entries:
- loss_dict[entry.filename].append(loss.item())
-
- optimizer.zero_grad()
- weights[0].grad = None
- loss.backward()
- if weights[0].grad is None:
- steps_without_grad += 1
+ try:
+ for i, entries in pbar:
+ hypernetwork.step = i + ititial_step
+ if len(loss_dict) > 0:
+ previous_mean_losses = [i[-1] for i in loss_dict.values()]
+ previous_mean_loss = mean(previous_mean_losses)
+
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = stack_conds([entry.cond for entry in entries]).to(devices.device)
+ # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+ del c
+
+ losses[hypernetwork.step % losses.shape[0]] = loss.item()
+ for entry in entries:
+ loss_dict[entry.filename].append(loss.item())
+
+ optimizer.zero_grad()
+ weights[0].grad = None
+ loss.backward()
+
+ if weights[0].grad is None:
+ steps_without_grad += 1
+ else:
+ steps_without_grad = 0
+ assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
+
+ optimizer.step()
+
+ steps_done = hypernetwork.step + 1
+
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+ raise RuntimeError("Loss diverged.")
+
+ if len(previous_mean_losses) > 1:
+ std = stdev(previous_mean_losses)
else:
- steps_without_grad = 0
- assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
-
- optimizer.step()
-
- steps_done = hypernetwork.step + 1
-
- if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
- raise RuntimeError("Loss diverged.")
-
- if len(previous_mean_losses) > 1:
- std = stdev(previous_mean_losses)
- else:
- std = 0
- dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
- pbar.set_description(dataset_loss_info)
-
- if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
- # Before saving, change name to match current checkpoint.
- hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
- hypernetwork.save(last_saved_file)
-
- textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
- "loss": f"{previous_mean_loss:.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{hypernetwork_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
-
- optimizer.zero_grad()
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
+ std = 0
+ dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
+ pbar.set_description(dataset_loss_info)
+
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
+ # Before saving, change name to match current checkpoint.
+ hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
+ hypernetwork.save(last_saved_file)
+
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
+ "loss": f"{previous_mean_loss:.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+
+ optimizer.zero_grad()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
- preview_text = p.prompt
+ preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images)>0 else None
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images)>0 else None
- if unload:
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
- if image is not None:
- shared.state.current_image = image
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
+ if image is not None:
+ shared.state.current_image = image
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
- shared.state.job_no = hypernetwork.step
+ shared.state.job_no = hypernetwork.step
- shared.state.textinfo = f"""
+ shared.state.textinfo = f"""
Loss: {previous_mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
@@ -510,7 +512,14 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}<br/>
"""
-
+ finally:
+ if weights:
+ for weight in weights:
+ weight.requires_grad = False
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
report_statistics(loss_dict)
checkpoint = sd_models.select_checkpoint()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 44f06443..fd7f0897 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -283,111 +283,113 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, entries in pbar:
- embedding.step = i + ititial_step
- scheduler.apply(optimizer, embedding.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = cond_model([entry.cond_text for entry in entries])
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
-
- losses[embedding.step % losses.shape[0]] = loss.item()
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- steps_done = embedding.step + 1
-
- epoch_num = embedding.step // len(ds)
- epoch_step = embedding.step % len(ds)
-
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
-
- if embedding_dir is not None and steps_done % save_embedding_every == 0:
- # Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{steps_done}'
- last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
- embedding.save(last_saved_file)
- embedding_yet_to_be_embedded = True
-
- write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
- "loss": f"{losses.mean():.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{embedding_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- do_not_reload_embeddings=True,
- )
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
- p.width = training_width
- p.height = training_height
+ try:
+ for i, entries in pbar:
+ embedding.step = i + ititial_step
+
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = cond_model([entry.cond_text for entry in entries])
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+
+ losses[embedding.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ steps_done = embedding.step + 1
+
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step % len(ds)
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
+
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
+ # Before saving, change name to match current checkpoint.
+ embedding.name = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
+ embedding.save(last_saved_file)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ do_not_reload_embeddings=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
- preview_text = p.prompt
+ preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0]
+ processed = processing.process_images(p)
+ image = processed.images[0]
- shared.state.current_image = image
+ shared.state.current_image = image
- if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
- info = PngImagePlugin.PngInfo()
- data = torch.load(last_saved_file)
- info.add_text("sd-ti-embedding", embedding_to_b64(data))
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
- title = "<{}>".format(data.get('name', '???'))
+ title = "<{}>".format(data.get('name', '???'))
- try:
- vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
- vectorSize = '?'
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
- checkpoint = sd_models.select_checkpoint()
- footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}v {}s'.format(vectorSize, steps_done)
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
- captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
- captioned_image = insert_image_data_embed(captioned_image, data)
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
- captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
- embedding_yet_to_be_embedded = False
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
- shared.state.job_no = embedding.step
+ shared.state.job_no = embedding.step
- shared.state.textinfo = f"""
+ shared.state.textinfo = f"""
Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -396,6 +398,9 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
+ finally:
+ if embedding and embedding.vec is not None:
+ embedding.vec.requires_grad = False
checkpoint = sd_models.select_checkpoint()
--
cgit v1.2.3
From a27d19de2eff633b6a39f9f4a5c0f2d6abb81bb5 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sat, 29 Oct 2022 19:44:05 +0700
Subject: Additional assert on dataset
---
modules/textual_inversion/dataset.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 8bb00d27..ad726577 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -42,6 +42,8 @@ class PersonalizedBase(Dataset):
self.lines = lines
assert data_root, 'dataset directory not specified'
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
cond_model = shared.sd_model.cond_stage_model
--
cgit v1.2.3
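The two new assertions make training fail fast with a readable message when the dataset directory is missing or empty, instead of erroring later inside image loading. A minimal standalone sketch of the same guard pattern (the helper name is illustrative):

import os

def validate_dataset_dir(data_root):
    # Mirrors the asserts added to PersonalizedBase above.
    assert data_root, 'dataset directory not specified'
    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert os.listdir(data_root), "Dataset directory is empty"

validate_dataset_dir("training/images")  # raises AssertionError if missing or empty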
From de1dc0d279a877d5d9f512befe30a7d7e5cf3881 Mon Sep 17 00:00:00 2001
From: Martin Cairns <4314538+MartinCairnsSQL@users.noreply.github.com>
Date: Sat, 29 Oct 2022 15:23:19 +0100
Subject: Add adjust_steps_if_invalid to find next valid step for ddim uniform
sampler
---
modules/sd_samplers.py | 28 +++++++++++++++-------------
1 file changed, 15 insertions(+), 13 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 3670b57d..aca014e8 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,5 +1,6 @@
from collections import namedtuple
import numpy as np
+from math import floor
import torch
import tqdm
from PIL import Image
@@ -205,17 +206,22 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
+
+ def adjust_steps_if_invalid(self, p, num_steps):
+ if self.config.name == 'DDIM' and p.ddim_discretize == 'uniform':
+ valid_step = 999 / (1000 // num_steps)
+ if valid_step == floor(valid_step):
+ return int(valid_step) + 1
+
+ return num_steps
+
+
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
-
+ steps = self.adjust_steps_if_invalid(p, steps)
self.initialize(p)
- # existing code fails with certain step counts, like 9
- try:
- self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
- except Exception:
- self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
-
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
@@ -239,18 +245,14 @@ class VanillaStableDiffusionSampler:
self.last_latent = x
self.step = 0
- steps = steps or p.steps
+ steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
- # existing code fails with certain step counts, like 9
- try:
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
- except Exception:
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim
--
cgit v1.2.3
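With uniform discretization, DDIM strides the 1000 DDPM timesteps by 1000 // num_steps, and the removed try/except notes that certain counts, like 9, made make_schedule fail. The new check detects exactly those counts: when 999 divided by the stride is a whole number, the requested count is invalid and gets bumped to the next valid one. A standalone sketch of the arithmetic:

from math import floor

def adjust_steps_if_invalid(num_steps, ddpm_steps=1000):
    # When 999 / (1000 // num_steps) is an exact integer, the uniform
    # schedule for num_steps is invalid, so return the next valid count.
    valid_step = (ddpm_steps - 1) / (ddpm_steps // num_steps)
    if valid_step == floor(valid_step):
        return int(valid_step) + 1
    return num_steps

for n in (8, 9, 10, 20):
    print(n, '->', adjust_steps_if_invalid(n))
# 9 -> 10 because 1000 // 9 == 111 and 999 / 111 == 9.0 exactly;
# the other counts come back unchanged.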
From 44ab954fabb9c1273366ebdca47f8da394d61aab Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Sat, 29 Oct 2022 10:02:56 -0700
Subject: Fix latent upscale highres fix #3888
---
modules/processing.py | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 548eec29..f18b7db2 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -653,6 +653,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if opts.use_scale_latent_for_hires_fix:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ image_conditioning = self.txt2img_image_conditioning(samples)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
@@ -674,6 +675,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ image_conditioning = self.img2img_image_conditioning(
+ decoded_samples,
+ samples,
+ decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
+ )
+
shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
@@ -684,11 +691,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- image_conditioning = self.img2img_image_conditioning(
- decoded_samples,
- samples,
- decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
- )
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples
--
cgit v1.2.3
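The fix builds image_conditioning inside both hires-fix branches. Previously it was computed once after the if/else from decoded_samples, which only exists on the decode/re-encode path, so the latent-upscale path crashed (#3888). The latent branch itself just resizes the latent with bilinear interpolation; a sketch with illustrative shapes (4 latent channels and the usual f=8 downscale factor assumed):

import torch
import torch.nn.functional as F

samples = torch.randn(1, 4, 64, 64)   # a 512x512 image in latent space
target = (1024 // 8, 1024 // 8)       # hires-fix target size, f=8 assumed
upscaled = F.interpolate(samples, size=target, mode="bilinear")
print(upscaled.shape)                 # torch.Size([1, 4, 128, 128])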
From ab05a74ead9fabb45dd099990e34061c7eb02ca3 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sun, 30 Oct 2022 00:32:02 +0700
Subject: Revert "Add cleanup after training"
This reverts commit 3ce2bfdf95bd5f26d0f6e250e67338ada91980d1.
---
modules/hypernetworks/hypernetwork.py | 201 ++++++++++++-------------
modules/textual_inversion/textual_inversion.py | 185 +++++++++++------------
2 files changed, 186 insertions(+), 200 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 170d5ea4..38f35c58 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -398,112 +398,110 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
forced_filename = ""
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-
- try:
- for i, entries in pbar:
- hypernetwork.step = i + ititial_step
- if len(loss_dict) > 0:
- previous_mean_losses = [i[-1] for i in loss_dict.values()]
- previous_mean_loss = mean(previous_mean_losses)
-
- scheduler.apply(optimizer, hypernetwork.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = stack_conds([entry.cond for entry in entries]).to(devices.device)
- # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
- del c
-
- losses[hypernetwork.step % losses.shape[0]] = loss.item()
- for entry in entries:
- loss_dict[entry.filename].append(loss.item())
-
- optimizer.zero_grad()
- weights[0].grad = None
- loss.backward()
-
- if weights[0].grad is None:
- steps_without_grad += 1
- else:
- steps_without_grad = 0
- assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
-
- optimizer.step()
-
- steps_done = hypernetwork.step + 1
-
- if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
- raise RuntimeError("Loss diverged.")
+ for i, entries in pbar:
+ hypernetwork.step = i + ititial_step
+ if len(loss_dict) > 0:
+ previous_mean_losses = [i[-1] for i in loss_dict.values()]
+ previous_mean_loss = mean(previous_mean_losses)
- if len(previous_mean_losses) > 1:
- std = stdev(previous_mean_losses)
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = stack_conds([entry.cond for entry in entries]).to(devices.device)
+ # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+ del c
+
+ losses[hypernetwork.step % losses.shape[0]] = loss.item()
+ for entry in entries:
+ loss_dict[entry.filename].append(loss.item())
+
+ optimizer.zero_grad()
+ weights[0].grad = None
+ loss.backward()
+
+ if weights[0].grad is None:
+ steps_without_grad += 1
else:
- std = 0
- dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
- pbar.set_description(dataset_loss_info)
-
- if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
- # Before saving, change name to match current checkpoint.
- hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
- hypernetwork.save(last_saved_file)
-
- textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
- "loss": f"{previous_mean_loss:.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{hypernetwork_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
-
- optimizer.zero_grad()
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
+ steps_without_grad = 0
+ assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
+ optimizer.step()
- preview_text = p.prompt
+ steps_done = hypernetwork.step + 1
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images)>0 else None
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+ raise RuntimeError("Loss diverged.")
+
+ if len(previous_mean_losses) > 1:
+ std = stdev(previous_mean_losses)
+ else:
+ std = 0
+ dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
+ pbar.set_description(dataset_loss_info)
+
+ if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
+ # Before saving, change name to match current checkpoint.
+ hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
+ hypernetwork.save(last_saved_file)
+
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
+ "loss": f"{previous_mean_loss:.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{hypernetwork_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+
+ optimizer.zero_grad()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
- if unload:
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
- if image is not None:
- shared.state.current_image = image
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images)>0 else None
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
- shared.state.job_no = hypernetwork.step
+ if image is not None:
+ shared.state.current_image = image
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
- shared.state.textinfo = f"""
+ shared.state.job_no = hypernetwork.step
+
+ shared.state.textinfo = f"""
Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
@@ -512,14 +510,7 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
- finally:
- if weights:
- for weight in weights:
- weight.requires_grad = False
- if unload:
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
+
report_statistics(loss_dict)
checkpoint = sd_models.select_checkpoint()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fd7f0897..44f06443 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -283,113 +283,111 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+ for i, entries in pbar:
+ embedding.step = i + ititial_step
- try:
- for i, entries in pbar:
- embedding.step = i + ititial_step
-
- scheduler.apply(optimizer, embedding.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = cond_model([entry.cond_text for entry in entries])
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
-
- losses[embedding.step % losses.shape[0]] = loss.item()
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- steps_done = embedding.step + 1
-
- epoch_num = embedding.step // len(ds)
- epoch_step = embedding.step % len(ds)
-
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
-
- if embedding_dir is not None and steps_done % save_embedding_every == 0:
- # Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{steps_done}'
- last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
- embedding.save(last_saved_file)
- embedding_yet_to_be_embedded = True
-
- write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
- "loss": f"{losses.mean():.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if images_dir is not None and steps_done % create_image_every == 0:
- forced_filename = f'{embedding_name}-{steps_done}'
- last_saved_image = os.path.join(images_dir, forced_filename)
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- do_not_reload_embeddings=True,
- )
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
- p.width = training_width
- p.height = training_height
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = cond_model([entry.cond_text for entry in entries])
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+
+ losses[embedding.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ steps_done = embedding.step + 1
+
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step % len(ds)
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
+
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
+ # Before saving, change name to match current checkpoint.
+ embedding.name = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
+ embedding.save(last_saved_file)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ do_not_reload_embeddings=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
- preview_text = p.prompt
+ preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0]
+ processed = processing.process_images(p)
+ image = processed.images[0]
- shared.state.current_image = image
+ shared.state.current_image = image
- if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
- info = PngImagePlugin.PngInfo()
- data = torch.load(last_saved_file)
- info.add_text("sd-ti-embedding", embedding_to_b64(data))
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
- title = "<{}>".format(data.get('name', '???'))
+ title = "<{}>".format(data.get('name', '???'))
- try:
- vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
- vectorSize = '?'
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
- checkpoint = sd_models.select_checkpoint()
- footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}v {}s'.format(vectorSize, steps_done)
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
- captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
- captioned_image = insert_image_data_embed(captioned_image, data)
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
- captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
- embedding_yet_to_be_embedded = False
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
- last_saved_image += f", prompt: {preview_text}"
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
- shared.state.job_no = embedding.step
+ shared.state.job_no = embedding.step
- shared.state.textinfo = f"""
+ shared.state.textinfo = f"""
Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -398,9 +396,6 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
- finally:
- if embedding and embedding.vec is not None:
- embedding.vec.requires_grad = False
checkpoint = sd_models.select_checkpoint()
--
cgit v1.2.3
From 6e2ce4e735db64afcd0fe637327ca4ec78335706 Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Sat, 29 Oct 2022 10:35:51 -0700
Subject: Added image conditioning to latent upscale. Only computed if the mask
weight is not 1.0 to avoid extra memory. Also includes some code cleanup.
---
modules/processing.py | 29 +++++++++++------------------
1 file changed, 11 insertions(+), 18 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index f18b7db2..ee0e9e34 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -134,11 +134,7 @@ class StableDiffusionProcessing():
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- return torch.zeros(
- x.shape[0], 5, 1, 1,
- dtype=x.dtype,
- device=x.device
- )
+ return x.new_zeros(x.shape[0], 5, 1, 1)
height = height or self.height
width = width or self.width
@@ -156,11 +152,7 @@ class StableDiffusionProcessing():
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
# Dummy zero conditioning if we're not using inpainting model.
- return torch.zeros(
- latent_image.shape[0], 5, 1, 1,
- dtype=latent_image.dtype,
- device=latent_image.device
- )
+ return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
# Handle the different mask inputs
if image_mask is not None:
@@ -174,11 +166,10 @@ class StableDiffusionProcessing():
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
- conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
+ conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
- conditioning_mask = conditioning_mask.to(source_image.device)
conditioning_image = torch.lerp(
source_image,
source_image * (1.0 - conditioning_mask),
@@ -653,7 +644,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if opts.use_scale_latent_for_hires_fix:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
- image_conditioning = self.txt2img_image_conditioning(samples)
+
+ # Avoid making the inpainting conditioning unless necessary as
+ # this does need some extra compute to decode / encode the image again.
+ if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+ image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+ else:
+ image_conditioning = self.txt2img_image_conditioning(samples)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
@@ -675,11 +672,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
- image_conditioning = self.img2img_image_conditioning(
- decoded_samples,
- samples,
- decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
- )
+ image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
shared.state.nextjob()
--
cgit v1.2.3
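Besides gating the inpainting conditioning behind inpainting_mask_weight < 1.0, the cleanup swaps torch.zeros(..., dtype=..., device=...) for Tensor.new_zeros(), which inherits dtype and device from the source tensor in one call. A small sketch of the equivalence:

import torch

x = torch.ones(2, 3, dtype=torch.float16)

# new_zeros copies dtype and device from x, so both lines are equivalent.
a = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
b = x.new_zeros(x.shape[0], 5, 1, 1)
assert a.dtype == b.dtype and a.device == b.device and torch.equal(a, b)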
From a07f054c86f33360ff620d6a3fffdee366ab2d99 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sun, 30 Oct 2022 00:49:29 +0700
Subject: Add missing info on hypernetwork/embedding model log
Mentioned here: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/1528#discussioncomment-3991513
Also group the saving into one
---
modules/hypernetworks/hypernetwork.py | 31 +++++++++++++-------
modules/textual_inversion/textual_inversion.py | 39 +++++++++++++++++---------
2 files changed, 47 insertions(+), 23 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 38f35c58..86daf825 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -361,6 +361,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
images_dir = None
hypernetwork = shared.loaded_hypernetwork
+ checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
@@ -449,9 +450,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
- hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
- hypernetwork.save(last_saved_file)
+ hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}",
@@ -512,13 +513,23 @@ Last saved image: {html.escape(last_saved_image)}
"""
report_statistics(loss_dict)
- checkpoint = sd_models.select_checkpoint()
- hypernetwork.sd_checkpoint = checkpoint.hash
- hypernetwork.sd_checkpoint_name = checkpoint.model_name
- # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
- hypernetwork.name = hypernetwork_name
- filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
- hypernetwork.save(filename)
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
return hypernetwork, filename
+
+def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
+ old_hypernetwork_name = hypernetwork.name
+ old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
+ old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
+ try:
+ hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ hypernetwork.name = hypernetwork_name
+ hypernetwork.save(filename)
+ except:
+ hypernetwork.sd_checkpoint = old_sd_checkpoint
+ hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
+ hypernetwork.name = old_hypernetwork_name
+ raise
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 44f06443..ee9917ce 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -119,7 +119,7 @@ class EmbeddingDatabase:
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('hash', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
@@ -259,6 +259,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
+ checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0
if ititial_step > steps:
@@ -314,9 +315,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
- embedding.name = f'{embedding_name}-{steps_done}'
- last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
- embedding.save(last_saved_file)
+ embedding_name_every = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
+ save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
@@ -397,14 +398,26 @@ Last saved image: {html.escape(last_saved_image)}
"""
- checkpoint = sd_models.select_checkpoint()
-
- embedding.sd_checkpoint = checkpoint.hash
- embedding.sd_checkpoint_name = checkpoint.model_name
- embedding.cached_checksum = None
- # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
- embedding.name = embedding_name
- filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
- embedding.save(filename)
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+ save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
return embedding, filename
+
+def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+ old_embedding_name = embedding.name
+ old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
+ old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
+ old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
+ try:
+ embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint_name = checkpoint.model_name
+ if remove_cached_checksum:
+ embedding.cached_checksum = None
+ embedding.name = embedding_name
+ embedding.save(filename)
+ except:
+ embedding.sd_checkpoint = old_sd_checkpoint
+ embedding.sd_checkpoint_name = old_sd_checkpoint_name
+ embedding.name = old_embedding_name
+ embedding.cached_checksum = old_cached_checksum
+ raise
--
cgit v1.2.3
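Both new save helpers follow the same stash-and-restore pattern: the step-suffixed name and checkpoint info are applied only for the duration of save(), and the old attributes are put back if the write raises, so a failed save does not leave the in-memory object renamed. A generic sketch of the pattern (the helper is hypothetical):

def save_with_rollback(obj, filename, **updates):
    # Temporarily apply attribute updates, save, and roll back on failure,
    # mirroring save_hypernetwork() and save_embedding() above.
    old = {k: getattr(obj, k, None) for k in updates}
    try:
        for k, v in updates.items():
            setattr(obj, k, v)
        obj.save(filename)
    except Exception:
        for k, v in old.items():
            setattr(obj, k, v)
        raise

save_embedding() above additionally clears cached_checksum before writing, since the saved file no longer matches any previously computed checksum.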
From 3d58510f214c645ce5cdb261aa47df6573b239e9 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sun, 30 Oct 2022 00:54:59 +0700
Subject: Fix dataset still being loaded even when training will be skipped
---
modules/hypernetworks/hypernetwork.py | 2 +-
modules/textual_inversion/textual_inversion.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 86daf825..07acadc9 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -364,7 +364,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0
- if ititial_step > steps:
+ if ititial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return hypernetwork, filename
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index ee9917ce..e0babb46 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -262,7 +262,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0
- if ititial_step > steps:
+ if ititial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return embedding, filename
--
cgit v1.2.3
From 4609b83cd496013a05e77c42af031d89f07785a9 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sat, 29 Oct 2022 16:09:19 -0300
Subject: Add PNG Info endpoint
---
modules/api/api.py | 12 +++++++++---
modules/api/models.py | 9 ++++++++-
2 files changed, 17 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 49c213ea..8fcd068d 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -5,7 +5,7 @@ import modules.shared as shared
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
-from modules.extras import run_extras
+from modules.extras import run_extras, run_pnginfo
def upscaler_to_index(name: str):
try:
@@ -32,6 +32,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
+ self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -125,8 +126,13 @@ class Api:
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
- def pnginfoapi(self):
- raise NotImplementedError
+ def pnginfoapi(self, req:PNGInfoRequest):
+ if(not req.image.strip()):
+ return PNGInfoResponse(info="")
+
+ result = run_pnginfo(decode_base64_to_image(req.image.strip()))
+
+ return PNGInfoResponse(info=result[1])
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
index dd122321..58e8e58b 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,4 +1,5 @@
import inspect
+from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
@@ -148,4 +149,10 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: list[str] = Field(title="Images", description="The generated images in base64 format.")
\ No newline at end of file
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.")
+
+class PNGInfoRequest(BaseModel):
+ image: str = Field(title="Image", description="The base64 encoded PNG image")
+
+class PNGInfoResponse(BaseModel):
+ info: str = Field(title="Image info", description="A string with all the info the image had")
\ No newline at end of file
--
cgit v1.2.3
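The new route takes a base64-encoded image and returns whatever parameter text run_pnginfo() extracts from it. A hedged client sketch using the third-party requests library; the address assumes a default local webui instance:

import base64
import requests  # pip install requests

with open("sample.png", "rb") as f:   # any PNG saved by the webui
    b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://127.0.0.1:7860/sdapi/v1/png-info",  # default address, an assumption
    json={"image": b64},
)
print(resp.json()["info"])  # the generation parameters embedded in the PNG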
From 83a1f44ae26cb89492064bb8be0321b14a75efe4 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sat, 29 Oct 2022 16:10:00 -0300
Subject: Fix space
---
modules/api/api.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 8fcd068d..d0f488ca 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -126,7 +126,7 @@ class Api:
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
- def pnginfoapi(self, req:PNGInfoRequest):
+ def pnginfoapi(self, req: PNGInfoRequest):
if(not req.image.strip()):
return PNGInfoResponse(info="")
--
cgit v1.2.3
From 9bb6b6509aff8c1e6546d5a798ef9e9922758dc4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 29 Oct 2022 22:20:02 +0300
Subject: add postprocess call for scripts
---
modules/processing.py | 12 +++++++++---
modules/scripts.py | 24 +++++++++++++++++++++---
2 files changed, 30 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 548eec29..50343846 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -478,7 +478,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
model_hijack.embedding_db.load_textual_inversion_embeddings()
if p.scripts is not None:
- p.scripts.run_alwayson_scripts(p)
+ p.scripts.process(p)
infotexts = []
output_images = []
@@ -501,7 +501,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
- if (len(prompts) == 0):
+ if len(prompts) == 0:
break
with devices.autocast():
@@ -590,7 +590,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ if p.scripts is not None:
+ p.scripts.postprocess(p, res)
+
+ return res
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
diff --git a/modules/scripts.py b/modules/scripts.py
index a7f36012..96e44bfd 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -64,7 +64,16 @@ class Script:
def process(self, p, *args):
"""
This function is called before processing begins for AlwaysVisible scripts.
- scripts. You can modify the processing object (p) here, inject hooks, etc.
+ You can modify the processing object (p) here, inject hooks, etc.
+ args contains all values returned by components from ui()
+ """
+
+ pass
+
+ def postprocess(self, p, processed, *args):
+ """
+ This function is called after processing ends for AlwaysVisible scripts.
+ args contains all values returned by components from ui()
"""
pass
@@ -289,13 +298,22 @@ class ScriptRunner:
return processed
- def run_alwayson_scripts(self, p):
+ def process(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
+ print(f"Error running process: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def postprocess(self, p, processed):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess(p, processed, *script_args)
+ except Exception:
+ print(f"Error running postprocess: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def reload_sources(self, cache):
--
cgit v1.2.3
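With this change an always-on script gets two hooks: process(p) before generation, where the processing object can still be modified, and postprocess(p, processed) after generation, receiving the new Processed result. A minimal sketch of a script using both; it assumes the Script/AlwaysVisible API from modules.scripts as referenced in the diff, with the other details illustrative:

import modules.scripts as scripts

class ExampleScript(scripts.Script):
    def title(self):
        return "Example always-on script"

    def show(self, is_img2img):
        return scripts.AlwaysVisible  # run for every generation

    def process(self, p, *args):
        # Called before processing begins; p may still be modified here.
        print(f"about to generate with prompt: {p.prompt!r}")

    def postprocess(self, p, processed, *args):
        # Called after processing ends, with the Processed result.
        print(f"generated {len(processed.images)} image(s)")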
From f62db4d5c753bc32d2ae166606ce41f4c5fa5c43 Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 03:55:43 +0800
Subject: fix progress response model
---
modules/api/api.py | 30 ------------------------------
modules/api/models.py | 8 ++++----
2 files changed, 4 insertions(+), 34 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index e93cddcb..7e8522a2 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,33 +1,3 @@
-# import time
-
-# from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
-# from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-# from modules.sd_samplers import all_samplers
-# from modules.extras import run_pnginfo
-# import modules.shared as shared
-# from modules import devices
-# import uvicorn
-# from fastapi import Body, APIRouter, HTTPException
-# from fastapi.responses import JSONResponse
-# from pydantic import BaseModel, Field, Json
-# from typing import List
-# import json
-# import io
-# import base64
-# from PIL import Image
-
-# sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
-
-# class TextToImageResponse(BaseModel):
-# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
-# parameters: Json
-# info: Json
-
-# class ImageToImageResponse(BaseModel):
-# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
-# parameters: Json
-# info: Json
-
import time
import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
diff --git a/modules/api/models.py b/modules/api/models.py
index 8d4abc39..e1762fb9 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,6 +1,6 @@
import inspect
from click import prompt
-from pydantic import BaseModel, Field, create_model
+from pydantic import BaseModel, Field, Json, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
@@ -158,6 +158,6 @@ class PNGInfoResponse(BaseModel):
info: str = Field(title="Image info", description="A string with all the info the image had")
class ProgressResponse(BaseModel):
- progress: float
- eta_relative: float
- state: dict
+ progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
+ eta_relative: float = Field(title="ETA in secs")
+ state: Json
--
cgit v1.2.3
From e9c6c2a51f972fd7cd88ea740ade4ac3d8108b67 Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 04:02:56 +0800
Subject: add description for state field
---
modules/api/models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/api/models.py b/modules/api/models.py
index e1762fb9..709ab5a6 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -160,4 +160,4 @@ class PNGInfoResponse(BaseModel):
class ProgressResponse(BaseModel):
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
- state: Json
+ state: Json = Field(title="State", description="The current state snapshot")
--
cgit v1.2.3
From 88f46a5bec610cf03641f18becbe3deda541e982 Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 05:04:29 +0800
Subject: update progress response model
---
modules/api/api.py | 6 +++---
modules/api/models.py | 4 ++--
modules/shared.py | 4 ++--
3 files changed, 7 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 7e8522a2..5912d289 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -61,7 +61,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
- self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
+ self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -171,7 +171,7 @@ class Api:
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
- return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())
+ return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
# avoid dividing zero
progress = 0.01
@@ -187,7 +187,7 @@ class Api:
progress = min(progress, 1)
- return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
+ return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict())
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
index 709ab5a6..0ab85ec5 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,6 +1,6 @@
import inspect
from click import prompt
-from pydantic import BaseModel, Field, Json, create_model
+from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
@@ -160,4 +160,4 @@ class PNGInfoResponse(BaseModel):
class ProgressResponse(BaseModel):
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
- state: Json = Field(title="State", description="The current state snapshot")
+ state: dict = Field(title="State", description="The current state snapshot")
diff --git a/modules/shared.py b/modules/shared.py
index 0f4c035d..f7b0990c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -147,7 +147,7 @@ class State:
def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
- def js(self):
+ def dict(self):
obj = {
"skipped": self.skipped,
"interrupted": self.skipped,
@@ -158,7 +158,7 @@ class State:
"sampling_steps": self.sampling_steps,
}
- return json.dumps(obj)
+ return obj
state = State()
--
cgit v1.2.3
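Renaming State.js() to State.dict() and returning a plain dict lets the response model declare state as dict, so FastAPI serializes it inline; the earlier Json-typed field expected a JSON-encoded string instead. A small pydantic sketch of the resulting shape (field set abbreviated):

from pydantic import BaseModel, Field

class ProgressResponse(BaseModel):
    progress: float = Field(title="Progress")
    state: dict = Field(title="State")

print(ProgressResponse(progress=0.5, state={"job_no": 3}).json())
# {"progress": 0.5, "state": {"job_no": 3}}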
From 39f55c3c35873bc7dd9792cb2155746a1c3d4292 Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Sat, 29 Oct 2022 14:13:02 -0700
Subject: Re-add explicit device move
---
modules/processing.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index ee0e9e34..d07e3db9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -170,6 +170,7 @@ class StableDiffusionProcessing():
# Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
+ conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
conditioning_image = torch.lerp(
source_image,
source_image * (1.0 - conditioning_mask),
--
cgit v1.2.3
From 9f104b53c425e248595e5b6481336d2a339e015e Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 05:19:17 +0800
Subject: preview current image when opts.show_progress_every_n_steps is
enabled
---
modules/api/api.py | 8 ++++++--
modules/api/models.py | 1 +
2 files changed, 7 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 5912d289..e960bb7b 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,7 +1,7 @@
import time
import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
from modules import devices
from modules.api.models import *
@@ -187,7 +187,11 @@ class Api:
progress = min(progress, 1)
- return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict())
+ current_image = None
+ if shared.state.current_image:
+ current_image = encode_pil_to_base64(shared.state.current_image)
+
+ return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
index 0ab85ec5..c8bc719a 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -161,3 +161,4 @@ class ProgressResponse(BaseModel):
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
+ current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
--
cgit v1.2.3
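
A client can now poll /sdapi/v1/progress and decode the preview. A hedged sketch (host, port, and the data-URL prefix handling are assumptions about a default local install and gradio's encoder):

    import base64, io, requests
    from PIL import Image

    resp = requests.get("http://127.0.0.1:7860/sdapi/v1/progress").json()
    if resp.get("current_image"):
        data = resp["current_image"]
        if "," in data:                      # strip a possible data-URL header
            data = data.split(",", 1)[1]
        Image.open(io.BytesIO(base64.b64decode(data))).save("preview.png")
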
From 66d038f6a41507af2243ff1f6618a745a092c290 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Sat, 29 Oct 2022 15:00:08 -0700
Subject: Read hypernet strength from PNG info.
---
modules/generation_parameters_copypaste.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index bbaad42e..59c6d7da 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -66,6 +66,7 @@ def integrate_settings_paste_fields(component_dict):
settings_map = {
'sd_hypernetwork': 'Hypernet',
+ 'sd_hypernetwork_strength': 'Hypernetwork strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
--
cgit v1.2.3
From 9f4f894d74b57c3d02ebccaa59f9c22fca2b6c90 Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 06:03:32 +0800
Subject: allow skip current image in progress api
---
modules/api/api.py | 4 ++--
modules/api/models.py | 3 +++
2 files changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index e960bb7b..5c5b210f 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -167,7 +167,7 @@ class Api:
return PNGInfoResponse(info=result[1])
- def progressapi(self):
+ def progressapi(self, req: ProgressRequest = Depends()):
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
@@ -188,7 +188,7 @@ class Api:
progress = min(progress, 1)
current_image = None
- if shared.state.current_image:
+ if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
diff --git a/modules/api/models.py b/modules/api/models.py
index c8bc719a..9ee42a17 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -157,6 +157,9 @@ class PNGInfoRequest(BaseModel):
class PNGInfoResponse(BaseModel):
info: str = Field(title="Image info", description="A string with all the info the image had")
+class ProgressRequest(BaseModel):
+ skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
+
class ProgressResponse(BaseModel):
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
--
cgit v1.2.3
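
Depends() on a pydantic model is what turns skip_current_image into a query parameter on this GET route. A self-contained sketch of the mechanism (route path illustrative):

    from fastapi import Depends, FastAPI
    from pydantic import BaseModel, Field

    app = FastAPI()

    class ProgressRequest(BaseModel):
        skip_current_image: bool = Field(default=False)

    @app.get("/progress")
    def progress(req: ProgressRequest = Depends()):
        # GET /progress?skip_current_image=true populates req from the query string
        return {"skip": req.skip_current_image}
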
From 05a657dd357eaca6940c4775daa946bd33f1167d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 30 Oct 2022 07:36:56 +0300
Subject: fix broken hires fix
---
modules/processing.py | 7 ++-----
1 file changed, 2 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 50343846..947ce6fa 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -686,15 +686,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ image_conditioning = self.txt2img_image_conditioning(x)
+
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
- image_conditioning = self.img2img_image_conditioning(
- decoded_samples,
- samples,
- decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
- )
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples
--
cgit v1.2.3
From 61836bd544fc8f4ef62f311c9d5964fbdaeb3f4c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 30 Oct 2022 08:48:53 +0300
Subject: shorten Hypernetwork strength in infotext and omit it when it's the
default value.
---
modules/generation_parameters_copypaste.py | 2 +-
modules/processing.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 59c6d7da..df70c728 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -66,7 +66,7 @@ def integrate_settings_paste_fields(component_dict):
settings_map = {
'sd_hypernetwork': 'Hypernet',
- 'sd_hypernetwork_strength': 'Hypernetwork strength',
+ 'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
diff --git a/modules/processing.py b/modules/processing.py
index ecaa78e2..b1df4918 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -396,7 +396,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
- "Hypernetwork strength": (None if shared.loaded_hypernetwork is None else shared.opts.sd_hypernetwork_strength),
+ "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
--
cgit v1.2.3
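
The None-guard works because infotext generation drops entries whose value is None, so the strength only appears when it differs from the default. A hypothetical reduction of that behavior:

    params = {
        "Hypernet": "anime_3",
        "Hypernet strength": None,  # strength >= 1 is the default, so it is omitted
        "Batch size": None,
    }
    infotext = ", ".join(f"{k}: {v}" for k, v in params.items() if v is not None)
    print(infotext)  # Hypernet: anime_3
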
From 149784202cca8612b43629c601ee27cfda64e623 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 30 Oct 2022 09:10:22 +0300
Subject: rework #3722 to not introduce duplicate code
---
modules/api/api.py | 43 +++++++++++++------------------------------
modules/shared.py | 22 +++++++++++++++++++---
2 files changed, 32 insertions(+), 33 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 5c5b210f..6c06d449 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -9,31 +9,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
-# copy from wrap_gradio_gpu_call of webui.py
-# because queue lock will be acquired in api handlers
-# and time start needs to be set
-# the function has been modified into two parts
-
-def before_gpu_call():
- devices.torch_gc()
-
- shared.state.sampling_step = 0
- shared.state.job_count = -1
- shared.state.job_no = 0
- shared.state.job_timestamp = shared.state.get_job_timestamp()
- shared.state.current_latent = None
- shared.state.current_image = None
- shared.state.current_image_sampling_step = 0
- shared.state.skipped = False
- shared.state.interrupted = False
- shared.state.textinfo = None
- shared.state.time_start = time.time()
-
-def after_gpu_call():
- shared.state.job = ""
- shared.state.job_count = 0
-
- devices.torch_gc()
def upscaler_to_index(name: str):
try:
@@ -41,8 +16,10 @@ def upscaler_to_index(name: str):
except:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
+
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+
def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
@@ -51,6 +28,7 @@ def setUpscalers(req: dict):
reqDict.pop('upscaler_2')
return reqDict
+
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
@@ -78,10 +56,13 @@ class Api:
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
- before_gpu_call()
+
+ shared.state.begin()
+
with self.queue_lock:
processed = process_images(p)
- after_gpu_call()
+
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
@@ -119,11 +100,13 @@ class Api:
imgs = [img] * p.batch_size
p.init_images = imgs
- # Override object param
- before_gpu_call()
+
+ shared.state.begin()
+
with self.queue_lock:
processed = process_images(p)
- after_gpu_call()
+
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
diff --git a/modules/shared.py b/modules/shared.py
index f7b0990c..e4f163c1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -144,9 +144,6 @@ class State:
self.sampling_step = 0
self.current_image_sampling_step = 0
- def get_job_timestamp(self):
- return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
-
def dict(self):
obj = {
"skipped": self.skipped,
@@ -160,6 +157,25 @@ class State:
return obj
+ def begin(self):
+ self.sampling_step = 0
+ self.job_count = -1
+ self.job_no = 0
+ self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+ self.current_latent = None
+ self.current_image = None
+ self.current_image_sampling_step = 0
+ self.skipped = False
+ self.interrupted = False
+ self.textinfo = None
+
+ devices.torch_gc()
+
+ def end(self):
+ self.job = ""
+ self.job_count = 0
+
+ devices.torch_gc()
state = State()
--
cgit v1.2.3
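
The rework moves job bookkeeping onto State so API handlers no longer duplicate it. A toy illustration of the begin()/end() pairing (the real handlers call the methods directly; DummyState and the context manager are purely illustrative):

    from contextlib import contextmanager

    class DummyState:
        def begin(self):
            print("reset counters, previews, timestamps; torch_gc()")
        def end(self):
            print("clear job fields; torch_gc()")

    @contextmanager
    def job(state):
        state.begin()
        try:
            yield state
        finally:
            state.end()

    with job(DummyState()):
        print("process_images(...) under the queue lock")
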
From 71571e3f055237d71ba2d47756846ad1d73be00c Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Sun, 30 Oct 2022 00:35:40 -0700
Subject: Replaced master branch fix with updated fix.
---
modules/processing.py | 2 --
1 file changed, 2 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 3dd44d3a..512c484f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -688,8 +688,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- image_conditioning = self.txt2img_image_conditioning(x)
-
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
--
cgit v1.2.3
From be27fd4690b1eb6c74da1e31c9696a0f1901fbba Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 17:01:01 +0800
Subject: fix broken progress api by previous rework
---
modules/shared.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index e4f163c1..2c7d28a5 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -4,6 +4,7 @@ import json
import os
import sys
from collections import OrderedDict
+import time
import gradio as gr
import tqdm
@@ -132,6 +133,7 @@ class State:
current_image = None
current_image_sampling_step = 0
textinfo = None
+ time_start = None
def skip(self):
self.skipped = True
@@ -168,6 +170,7 @@ class State:
self.skipped = False
self.interrupted = False
self.textinfo = None
+ self.time_start = time.time()
devices.torch_gc()
--
cgit v1.2.3
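
time_start is what the relative-ETA arithmetic needs; the progress endpoint mirrors ui.py's check_progress_call. One plausible shape of that computation, as a sketch:

    import time

    def eta_relative(progress: float, time_start: float) -> float:
        # Linear extrapolation: if `progress` of the job took `elapsed`
        # seconds, the remainder should take elapsed / progress - elapsed.
        elapsed = time.time() - time_start
        if progress <= 0:
            return 0.0
        return elapsed / progress - elapsed
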
From 1a4ff2de6a835cd8cc1590bbc1a8dedb5ad37e5b Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 17:02:47 +0800
Subject: fix current image in progress api when parallel processing enabled
---
modules/api/api.py | 13 +++++++++++--
1 file changed, 11 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 6c06d449..97497f3f 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -3,10 +3,9 @@ import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
-from modules import devices
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers
+from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
from modules.extras import run_extras, run_pnginfo
@@ -170,6 +169,16 @@ class Api:
progress = min(progress, 1)
+ # copy from check_progress_call of ui.py
+
+ if shared.parallel_processing_allowed:
+ if shared.state.sampling_step - shared.state.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.state.current_latent is not None:
+ if shared.opts.show_progress_grid:
+ shared.state.current_image = samples_to_image_grid(shared.state.current_latent)
+ else:
+ shared.state.current_image = sample_to_image(shared.state.current_latent)
+ shared.state.current_image_sampling_step = shared.state.sampling_step
+
current_image = None
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
--
cgit v1.2.3
From 34c86c12b0a9d650d4e7c5be478bca34ad8ed048 Mon Sep 17 00:00:00 2001
From: Martin Cairns <4314538+MartinCairnsSQL@users.noreply.github.com>
Date: Sun, 30 Oct 2022 11:04:27 +0000
Subject: Include PLMS in adjust steps as it can also fail in the same way
---
modules/sd_samplers.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index aca014e8..8772db56 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -208,7 +208,7 @@ class VanillaStableDiffusionSampler:
def adjust_steps_if_invalid(self, p, num_steps):
- if self.config.name == 'DDIM' and p.ddim_discretize == 'uniform':
+ if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
if valid_step == floor(valid_step):
return int(valid_step) + 1
--
cgit v1.2.3
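
Uniform DDIM/PLMS schedules stride through the 1000 DDPM timesteps by 1000 // num_steps; when 999 / (1000 // num_steps) is whole, the last scheduled timestep overflows the valid range, so the count is bumped by one. A standalone sketch with worked values:

    from math import floor

    def adjust_steps_if_invalid(num_steps: int) -> int:
        valid_step = 999 / (1000 // num_steps)
        if valid_step == floor(valid_step):
            return int(valid_step) + 1
        return num_steps

    print(adjust_steps_if_invalid(333))  # 334 -- 333 steps would fail
    print(adjust_steps_if_invalid(250))  # 250 -- already safe
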
From 4b8a192f680101de247dca79e48974b53bf961fe Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sat, 29 Oct 2022 16:36:43 +0900
Subject: add optimizer save option to shared.opts
---
modules/shared.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index e4f163c1..065b893d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -286,6 +286,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
+ "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
--
cgit v1.2.3
From 20194fd9752a280306fb66b57b258609b0918c46 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sat, 29 Oct 2022 16:56:42 +0900
Subject: We have duplicate linear now
---
modules/hypernetworks/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index aad09ffc..c2d4b51c 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -9,7 +9,7 @@ from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
not_available = ["hardswish", "multiheadattention"]
-keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
--
cgit v1.2.3
From 9d96d7d0a0aa0a966a9aefd24342345eb65952ed Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Sun, 30 Oct 2022 20:39:04 +0900
Subject: resolve conflicts
---
modules/hypernetworks/hypernetwork.py | 44 ++++++++++++++++++++++++++++++-----
1 file changed, 38 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index a11e01d6..8f74cdea 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -21,6 +21,7 @@ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_norm
from collections import defaultdict, deque
from statistics import stdev, mean
+optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
@@ -139,6 +140,8 @@ class Hypernetwork:
self.weight_init = weight_init
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
+ self.optimizer_name = None
+ self.optimizer_state_dict = None
for size in enable_sizes or []:
self.layers[size] = (
@@ -171,6 +174,10 @@ class Hypernetwork:
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+ if self.optimizer_name is not None:
+ state_dict['optimizer_name'] = self.optimizer_name
+ if self.optimizer_state_dict:
+ state_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(state_dict, filename)
@@ -190,7 +197,14 @@ class Hypernetwork:
self.add_layer_norm = state_dict.get('is_layer_norm', False)
print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False)
- print(f"Dropout usage is set to {self.use_dropout}" )
+ print(f"Dropout usage is set to {self.use_dropout}")
+ self.optimizer_name = state_dict.get('optimizer_name', 'AdamW')
+ print(f"Optimizer name is {self.optimizer_name}")
+ self.optimizer_state_dict = state_dict.get('optimizer_state_dict', None)
+ if self.optimizer_state_dict:
+ print("Loaded existing optimizer from checkpoint")
+ else:
+ print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items():
if type(size) == int:
@@ -392,8 +406,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
- # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
- optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+    # Use the optimizer saved with the hypernetwork, or one specified as a UI option.
+ if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
+ optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+ else:
+ print(f"Optimizer type {optimizer_name} is not defined!")
+ optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+ optimizer_name = 'AdamW'
+ if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
+ try:
+ optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+ except RuntimeError as e:
+ print("Cannot resume from saved optimizer!")
+ print(e)
steps_without_grad = 0
@@ -455,8 +480,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+ hypernetwork.optimizer_name = optimizer_name
+ if shared.opts.save_optimizer_state:
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
-
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}",
"learn_rate": scheduler.learn_rate
@@ -514,14 +542,18 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
-
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+ hypernetwork.optimizer_name = optimizer_name
+ if shared.opts.save_optimizer_state:
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-
+ del optimizer
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
return hypernetwork, filename
+
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
old_hypernetwork_name = hypernetwork.name
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
--
cgit v1.2.3
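
The optimizer registry is a one-liner over torch.optim, and resuming is guarded because a saved state only fits a matching optimizer. A minimal standalone sketch of both halves:

    import inspect
    import torch

    # Every optimizer class torch.optim exposes, minus the abstract base:
    optimizer_dict = {name: cls for name, cls in inspect.getmembers(torch.optim, inspect.isclass)
                      if name != "Optimizer"}

    weights = [torch.nn.Parameter(torch.zeros(4))]
    opt = optimizer_dict.get("AdamW", torch.optim.AdamW)(params=weights, lr=1e-3)

    # load_state_dict raises if the saved state doesn't match the optimizer,
    # hence the RuntimeError guard in the training loop above:
    opt.load_state_dict(opt.state_dict())
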
From c9bb33dd43dbb9479ff1b70351df14508c89ac60 Mon Sep 17 00:00:00 2001
From: victorca25
Date: Sun, 30 Oct 2022 12:52:50 +0100
Subject: add resrgan 8x, allow using 1x and up to 8x extra models, move BSRGAN
model, add nearest
---
modules/esrgan_model.py | 17 +++++++++++++----
modules/modelloader.py | 3 +++
modules/ui.py | 2 +-
modules/upscaler.py | 17 ++++++++++++++++-
4 files changed, 33 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index a13cf6ac..c61669b4 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -50,6 +50,7 @@ def mod2normal(state_dict):
def resrgan2normal(state_dict, nb=23):
# this code is copied from https://github.com/victorca25/iNNfer
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+ re8x = 0
crt_net = {}
items = []
for k, v in state_dict.items():
@@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23):
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
- crt_net['model.8.weight'] = state_dict['conv_hr.weight']
- crt_net['model.8.bias'] = state_dict['conv_hr.bias']
- crt_net['model.10.weight'] = state_dict['conv_last.weight']
- crt_net['model.10.bias'] = state_dict['conv_last.bias']
+
+ if 'conv_up3.weight' in state_dict:
+ # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
+ re8x = 3
+ crt_net['model.9.weight'] = state_dict['conv_up3.weight']
+ crt_net['model.9.bias'] = state_dict['conv_up3.bias']
+
+ crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
+ crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
+ crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
+ crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
+
state_dict = crt_net
return state_dict
diff --git a/modules/modelloader.py b/modules/modelloader.py
index b0f2f33d..e4a6f8ac 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -85,6 +85,9 @@ def cleanup_models():
src_path = os.path.join(root_path, "ESRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path)
+ src_path = os.path.join(models_path, "BSRGAN")
+ dest_path = os.path.join(models_path, "ESRGAN")
+ move_files(src_path, dest_path, ".pth")
src_path = os.path.join(root_path, "gfpgan")
dest_path = os.path.join(models_path, "GFPGAN")
move_files(src_path, dest_path)
diff --git a/modules/ui.py b/modules/ui.py
index 5055ca64..47610f5c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1059,7 +1059,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'):
- upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
+ upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
with gr.TabItem('Scale to'):
with gr.Group():
with gr.Row():
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 6ab2fb40..83fde7ca 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -10,6 +10,7 @@ import modules.shared
from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
from modules.paths import models_path
@@ -57,7 +58,7 @@ class Upscaler:
dest_w = img.width * scale
dest_h = img.height * scale
for i in range(3):
- if img.width >= dest_w and img.height >= dest_h:
+ if img.width > dest_w and img.height > dest_h:
break
img = self.do_upscale(img, selected_model)
if img.width != dest_w or img.height != dest_h:
@@ -120,3 +121,17 @@ class UpscalerLanczos(Upscaler):
self.name = "Lanczos"
self.scalers = [UpscalerData("Lanczos", None, self)]
+
+class UpscalerNearest(Upscaler):
+ scalers = []
+
+ def do_upscale(self, img, selected_model=None):
+ return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
+
+ def load_model(self, _):
+ pass
+
+ def __init__(self, dirname=None):
+ super().__init__(False)
+ self.name = "Nearest"
+ self.scalers = [UpscalerData("Nearest", None, self)]
\ No newline at end of file
--
cgit v1.2.3
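
The re8x offset exists because an 8x model's extra conv_up3 block lands at model.9, pushing every later layer down by three indices. Illustrative arithmetic:

    for name, base in (("conv_hr", 8), ("conv_last", 10)):
        for re8x in (0, 3):  # 0: 2x/4x models; 3: 8x models with conv_up3
            print(f"{name} -> model.{base + re8x}")
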
From 423f22228306ae72d0480e25add9777c3c5d8fdf Mon Sep 17 00:00:00 2001
From: Maiko Tan
Date: Sun, 30 Oct 2022 22:46:43 +0800
Subject: feat: add app started callback
---
modules/script_callbacks.py | 15 +++++++++++++++
1 file changed, 15 insertions(+)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 6ea58d61..f5509629 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -3,6 +3,8 @@ import traceback
from collections import namedtuple
import inspect
+from fastapi import FastAPI
+from gradio import Blocks
def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
@@ -25,6 +27,7 @@ class ImageSaveParams:
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
+callbacks_app_started = []
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
@@ -40,6 +43,14 @@ def clear_callbacks():
callbacks_image_saved.clear()
+def app_started_callback(demo: Blocks, app: FastAPI):
+ for c in callbacks_app_started:
+ try:
+ c.callback(demo, app)
+ except Exception:
+ report_exception(c, 'app_started_callback')
+
+
def model_loaded_callback(sd_model):
for c in callbacks_model_loaded:
try:
@@ -91,6 +102,10 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
+def on_app_started(callback):
+ add_callback(callbacks_app_started, callback)
+
+
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
--
cgit v1.2.3
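
An extension script would consume the new hook roughly like this (runs inside the webui's script environment; the route path and handler are illustrative, not part of the commit):

    from fastapi import FastAPI
    from gradio import Blocks
    from modules import script_callbacks

    def my_app_started(demo: Blocks, app: FastAPI):
        @app.get("/my-extension/ping")   # hypothetical endpoint
        def ping():
            return {"ok": True}

    script_callbacks.on_app_started(my_app_started)
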
From cb31abcf58ea1f64266e6d821937eed058c35f4d Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sun, 30 Oct 2022 21:54:31 +0700
Subject: Settings to select VAE
---
modules/sd_models.py | 31 +++++--------
modules/sd_vae.py | 121 +++++++++++++++++++++++++++++++++++++++++++++++++++
modules/shared.py | 8 ++--
3 files changed, 136 insertions(+), 24 deletions(-)
create mode 100644 modules/sd_vae.py
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f86dc3ed..91ad4b5e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -8,7 +8,7 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices, script_callbacks
+from modules import shared, modelloader, devices, script_callbacks, sd_vae
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
@@ -160,12 +160,11 @@ def get_state_dict_from_checkpoint(pl_sd):
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
-
-def load_model_weights(model, checkpoint_info):
+def load_model_weights(model, checkpoint_info, force=False):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
- if checkpoint_info not in checkpoints_loaded:
+ if force or checkpoint_info not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
@@ -186,17 +185,7 @@ def load_model_weights(model, checkpoint_info):
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
- vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
-
- if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
- vae_file = shared.cmd_opts.vae_path
-
- if os.path.exists(vae_file):
- print(f"Loading VAE weights from: {vae_file}")
- vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
- model.first_stage_model.load_state_dict(vae_dict)
-
+ sd_vae.load_vae(model, checkpoint_file)
model.first_stage_model.to(devices.dtype_vae)
if shared.opts.sd_checkpoint_cache > 0:
@@ -213,7 +202,7 @@ def load_model_weights(model, checkpoint_info):
model.sd_checkpoint_info = checkpoint_info
-def load_model(checkpoint_info=None):
+def load_model(checkpoint_info=None, force=False):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -234,7 +223,7 @@ def load_model(checkpoint_info=None):
do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model)
- load_model_weights(sd_model, checkpoint_info)
+ load_model_weights(sd_model, checkpoint_info, force=force)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
@@ -252,16 +241,16 @@ def load_model(checkpoint_info=None):
return sd_model
-def reload_model_weights(sd_model, info=None):
+def reload_model_weights(sd_model, info=None, force=False):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
- if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+ if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force:
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
- load_model(checkpoint_info)
+ load_model(checkpoint_info, force=force)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -271,7 +260,7 @@ def reload_model_weights(sd_model, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_model_weights(sd_model, checkpoint_info)
+ load_model_weights(sd_model, checkpoint_info, force=force)
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
new file mode 100644
index 00000000..82764e55
--- /dev/null
+++ b/modules/sd_vae.py
@@ -0,0 +1,121 @@
+import torch
+import os
+from collections import namedtuple
+from modules import shared, devices
+from modules.paths import models_path
+import glob
+
+model_dir = "Stable-diffusion"
+model_path = os.path.abspath(os.path.join(models_path, model_dir))
+vae_dir = "VAE"
+vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
+
+vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+default_vae_dict = {"auto": "auto", "None": "None"}
+default_vae_list = ["auto", "None"]
+default_vae_values = [default_vae_dict[x] for x in default_vae_list]
+vae_dict = dict(default_vae_dict)
+vae_list = list(default_vae_list)
+first_load = True
+
+def get_filename(filepath):
+ return os.path.splitext(os.path.basename(filepath))[0]
+
+def refresh_vae_list(vae_path=vae_path, model_path=model_path):
+ global vae_dict, vae_list
+ res = {}
+ candidates = [
+ *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
+ *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
+ *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
+ *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True)
+ ]
+ if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
+ candidates.append(shared.cmd_opts.vae_path)
+ for filepath in candidates:
+ name = get_filename(filepath)
+ res[name] = filepath
+ vae_list.clear()
+ vae_list.extend(default_vae_list)
+ vae_list.extend(list(res.keys()))
+ vae_dict.clear()
+ vae_dict.update(default_vae_dict)
+ vae_dict.update(res)
+ return vae_list
+
+def load_vae(model, checkpoint_file, vae_file="auto"):
+ global first_load, vae_dict, vae_list
+ # save_settings = False
+
+ # if vae_file argument is provided, it takes priority
+ if vae_file and vae_file not in default_vae_list:
+ if not os.path.isfile(vae_file):
+ vae_file = "auto"
+ # save_settings = True
+ print("VAE provided as function argument doesn't exist")
+ # for the first load, if vae-path is provided, it takes priority and failure is reported
+ if first_load and shared.cmd_opts.vae_path is not None:
+ if os.path.isfile(shared.cmd_opts.vae_path):
+ vae_file = shared.cmd_opts.vae_path
+ # save_settings = True
+ # print("Using VAE provided as command line argument")
+ else:
+ print("VAE provided as command line argument doesn't exist")
+ # else, we load from settings
+ if vae_file == "auto" and shared.opts.sd_vae is not None:
+    # if the saved VAE setting isn't recognized, fall back to auto
+ vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
+ # if VAE selected but not found, fallback to auto
+ if vae_file not in default_vae_values and not os.path.isfile(vae_file):
+ vae_file = "auto"
+ print("Selected VAE doesn't exist")
+ # vae-path cmd arg takes priority for auto
+ if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
+ if os.path.isfile(shared.cmd_opts.vae_path):
+ vae_file = shared.cmd_opts.vae_path
+ print("Using VAE provided as command line argument")
+    # if still not found, try looking for ".vae.pt" beside the model
+ model_path = os.path.splitext(checkpoint_file)[0]
+ if vae_file == "auto":
+ vae_file_try = model_path + ".vae.pt"
+ if os.path.isfile(vae_file_try):
+ vae_file = vae_file_try
+ print("Using VAE found beside selected model")
+    # if still not found, try looking for ".vae.ckpt" beside the model
+ if vae_file == "auto":
+ vae_file_try = model_path + ".vae.ckpt"
+ if os.path.isfile(vae_file_try):
+ vae_file = vae_file_try
+ print("Using VAE found beside selected model")
+ # No more fallbacks for auto
+ if vae_file == "auto":
+ vae_file = None
+ # Last check, just because
+ if vae_file and not os.path.exists(vae_file):
+ vae_file = None
+
+ if vae_file:
+ print(f"Loading VAE weights from: {vae_file}")
+ vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ model.first_stage_model.load_state_dict(vae_dict_1)
+
+ # If vae used is not in dict, update it
+ # It will be removed on refresh though
+ if vae_file is not None:
+ vae_opt = get_filename(vae_file)
+ if vae_opt not in vae_dict:
+ vae_dict[vae_opt] = vae_file
+ vae_list.append(vae_opt)
+
+ """
+ # Save current VAE to VAE settings, maybe? will it work?
+ if save_settings:
+ if vae_file is None:
+ vae_opt = "None"
+
+ # shared.opts.sd_vae = vae_opt
+ """
+
+ first_load = False
+ model.first_stage_model.to(devices.dtype_vae)
diff --git a/modules/shared.py b/modules/shared.py
index e4f163c1..06440ac4 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -14,7 +14,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers, sd_models, localization
+from modules import sd_samplers, sd_models, localization, sd_vae
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
@@ -295,6 +295,7 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+ "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
@@ -407,11 +408,12 @@ class Options:
if bad_settings > 0:
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
- def onchange(self, key, func):
+ def onchange(self, key, func, call=True):
item = self.data_labels.get(key)
item.onchange = func
- func()
+ if call:
+ func()
def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
--
cgit v1.2.3
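
Condensing load_vae's branches, the selection priority reads: explicit argument, then --vae-path on the first load, then the sd_vae setting, then .vae.pt/.vae.ckpt beside the checkpoint, else nothing. A simplified sketch of that chain (it omits the vae_dict lookup and the explicit "None" selection):

    import os

    def resolve_vae_sketch(checkpoint_file, vae_file="auto", cmd_vae=None,
                           setting=None, first_load=True):
        if vae_file not in (None, "auto") and not os.path.isfile(vae_file):
            vae_file = "auto"                    # bad explicit argument falls back
        if first_load and cmd_vae and os.path.isfile(cmd_vae):
            vae_file = cmd_vae                   # --vae-path wins on the first load
        if vae_file == "auto" and setting:
            vae_file = setting                   # then the saved sd_vae setting
        base = os.path.splitext(checkpoint_file)[0]
        for ext in (".vae.pt", ".vae.ckpt"):     # then files beside the checkpoint
            if vae_file == "auto" and os.path.isfile(base + ext):
                vae_file = base + ext
        return None if vae_file == "auto" else vae_file
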
From e1b2ea6e0012ecc988385fc523d8fb50ea5d6be5 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sun, 30 Oct 2022 22:11:45 +0700
Subject: Change VAE search order and thus priority
---
modules/sd_vae.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 82764e55..0767b925 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -25,10 +25,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
global vae_dict, vae_list
res = {}
candidates = [
- *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
+ *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True)
+ *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
]
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
candidates.append(shared.cmd_opts.vae_path)
--
cgit v1.2.3
From d587586d3be2de061238defb8a556f03743287f6 Mon Sep 17 00:00:00 2001
From: mawr
Date: Mon, 31 Oct 2022 00:14:07 +0300
Subject: Added "--clip-models-path" switch to avoid using default
"~/.cache/clip" and enable to run under unprivileged user without homedir
---
modules/interrogate.py | 4 ++--
modules/shared.py | 1 +
2 files changed, 3 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 65b05d34..9769aa34 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -56,9 +56,9 @@ class InterrogateModels:
import clip
if self.running_on_cpu:
- model, preprocess = clip.load(clip_model_name, device="cpu")
+ model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
else:
- model, preprocess = clip.load(clip_model_name)
+ model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
model.eval()
model = model.to(devices.device_interrogate)
diff --git a/modules/shared.py b/modules/shared.py
index e4f163c1..36212031 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -51,6 +51,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
+parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
--
cgit v1.2.3
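
download_root is an existing parameter of OpenAI CLIP's loader, so redirecting the cache is a one-line change for callers too (the target path below is illustrative):

    import clip

    # Keep CLIP weights out of ~/.cache/clip, e.g. for a user with no homedir:
    model, preprocess = clip.load("ViT-L/14", device="cpu",
                                  download_root="/srv/models/clip")
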
From d9e4e4d7a09d4aee8ce249a3c8e91ce165b10fa5 Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Sun, 30 Oct 2022 15:33:02 -0700
Subject: Fix non-square full resolution inpainting.
---
modules/masking.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/masking.py b/modules/masking.py
index fd8d9241..a5c4d2da 100644
--- a/modules/masking.py
+++ b/modules/masking.py
@@ -49,7 +49,7 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w
ratio_processing = processing_width / processing_height
if ratio_crop_region > ratio_processing:
- desired_height = (x2 - x1) * ratio_processing
+ desired_height = (x2 - x1) / ratio_processing
desired_height_diff = int(desired_height - (y2-y1))
y1 -= desired_height_diff//2
y2 += desired_height_diff - desired_height_diff//2
--
cgit v1.2.3
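
With ratio_processing = width / height, the height matching a crop's width comes from division, not multiplication; the old code only looked right for square targets, where the ratio is 1. Worked numbers:

    processing_width, processing_height = 768, 512
    ratio_processing = processing_width / processing_height  # 1.5
    crop_width = 300
    print(crop_width / ratio_processing)  # 200.0 -- correct matching height
    print(crop_width * ratio_processing)  # 450.0 -- the old, wrong value
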
From 21fba39c609859a60616420afda3b34a89e00761 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 30 Oct 2022 23:45:52 +0000
Subject: Add callbacks and param objects
---
modules/script_callbacks.py | 27 +++++++++++++++++++++++++++
1 file changed, 27 insertions(+)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 6ea58d61..a206ea59 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -24,12 +24,22 @@ class ImageSaveParams:
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
+class CGFDenoiserParams:
+ def __init__(self, x_in, image_cond_in, sigma_in, sampling_step, total_sampling_steps):
+ self.x_in = x_in
+ self.image_cond_in = image_cond_in
+ self.sigma_in = sigma_in
+ self.sampling_step = sampling_step
+ self.total_sampling_steps = total_sampling_steps
+
+
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
callbacks_before_image_saved = []
callbacks_image_saved = []
+callbacks_cfg_denoiser = []
def clear_callbacks():
@@ -84,6 +94,14 @@ def image_saved_callback(params: ImageSaveParams):
report_exception(c, 'image_saved_callback')
+def cfg_denoiser_callback(params: CGFDenoiserParams):
+ for c in callbacks_cfg_denoiser:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'cfg_denoiser_callback')
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -130,3 +148,12 @@ def on_image_saved(callback):
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
add_callback(callbacks_image_saved, callback)
+
+
+def on_cfg_denoiser(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
+ The callback is called with one argument:
+ - params: CGFDenoiserParams - parameters to be passed to the inner model and sampling state details.
+ """
+ add_callback(callbacks_cfg_denoiser, callback)
+
--
cgit v1.2.3
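
A script registers for the hook like this (sketch; runs inside the webui's script environment, and the logging body is illustrative):

    from modules import script_callbacks

    def on_denoise(params):
        # params is a CGFDenoiserParams: x_in, image_cond_in, sigma_in,
        # sampling_step, total_sampling_steps
        if params.sampling_step == 0:
            print(f"denoising {tuple(params.x_in.shape)} over {params.total_sampling_steps} steps")

    script_callbacks.on_cfg_denoiser(on_denoise)
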
From 8906be85ac91310b37dccddc44f23631eb7a15f5 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 30 Oct 2022 23:47:08 +0000
Subject: add callback cleardown
---
modules/script_callbacks.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index a206ea59..b0b8dc47 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -48,7 +48,7 @@ def clear_callbacks():
callbacks_ui_settings.clear()
callbacks_before_image_saved.clear()
callbacks_image_saved.clear()
-
+ callbacks_cfg_denoiser.clear()
def model_loaded_callback(sd_model):
for c in callbacks_model_loaded:
--
cgit v1.2.3
From 8ae0ea9deaa5a09d1e0aa8b2f8e97c38d71cdbda Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 30 Oct 2022 23:48:33 +0000
Subject: Add callback to sd_samplers
---
modules/sd_samplers.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 3670b57d..30cb5c4b 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -11,6 +11,7 @@ from modules import prompt_parser, devices, processing, images
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
+from modules.script_callbacks import CGFDenoiserParams, cfg_denoiser_callback
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -278,6 +279,8 @@ class CFGDenoiser(torch.nn.Module):
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ cfg_denoiser_callback(CGFDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps))
+
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
--
cgit v1.2.3
From 081df45da47feadfb055552c0ad5c6e6ecdd9f28 Mon Sep 17 00:00:00 2001
From: Maiko Sinkyaet Tan
Date: Mon, 31 Oct 2022 08:47:43 +0800
Subject: docs: add python doc (?)
not sure if this is available...
---
modules/script_callbacks.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index f5509629..d8aa6f00 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -103,6 +103,8 @@ def add_callback(callbacks, fun):
def on_app_started(callback):
+ """register a function to be called when the webui started, the gradio `Block` component and
+ fastapi `FastAPI` object are passed as the arguments"""
add_callback(callbacks_app_started, callback)
--
cgit v1.2.3
From adaa699e3888e1396162083d65c63cd6774cc6b0 Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sun, 30 Oct 2022 18:08:40 +0800
Subject: prototype interrupt api
---
modules/api/api.py | 6 ++++++
1 file changed, 6 insertions(+)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 6c06d449..e702c9c0 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -40,6 +40,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+ self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -176,6 +177,11 @@ class Api:
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
+ def interruptapi(self):
+ shared.state.interrupt()
+
+ return {}
+
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)
--
cgit v1.2.3
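
Client-side, interrupting a running job is a single POST (assuming a default local install):

    import requests

    base = "http://127.0.0.1:7860"
    requests.post(f"{base}/sdapi/v1/interrupt")  # request the current job to stop
    print(requests.get(f"{base}/sdapi/v1/progress").json()["state"])
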
From 29f758afe9c09834a1cd2dac3f5dd6fbf7dfeaa7 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Mon, 31 Oct 2022 02:39:55 +0000
Subject: Extend extras image cache with upscale_first arg
---
modules/extras.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index 4d51088b..8e2ab35c 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -141,7 +141,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
info_hash=hash(info),
- args_hash=hash(upscale_args))
+ args_hash=hash((upscale_args, upscale_first)))
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
--
cgit v1.2.3
From b96d0c4e9ecec3c856b9b4ec795dbd0d34fcac51 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Mon, 31 Oct 2022 14:42:28 +0700
Subject: Fix typo from previous commit
---
modules/sd_vae.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 0767b925..2ce44d5f 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -27,8 +27,8 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
candidates = [
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True)
- *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
+ *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
+ *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
]
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
candidates.append(shared.cmd_opts.vae_path)
--
cgit v1.2.3
From 726769da35970f4c100fa7edf11850f9dc059c41 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Mon, 31 Oct 2022 15:19:34 +0700
Subject: Checkpoint cache by combination key of checkpoint and vae
---
modules/sd_models.py | 27 ++++++++++++++++-----------
modules/sd_vae.py | 8 +++++++-
2 files changed, 23 insertions(+), 12 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 91ad4b5e..850f7b7b 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -160,11 +160,15 @@ def get_state_dict_from_checkpoint(pl_sd):
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
-def load_model_weights(model, checkpoint_info, force=False):
+def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
- if force or checkpoint_info not in checkpoints_loaded:
+ vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+
+ checkpoint_key = (checkpoint_info, vae_file)
+
+ if checkpoint_key not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
@@ -185,24 +189,25 @@ def load_model_weights(model, checkpoint_info, force=False):
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
- sd_vae.load_vae(model, checkpoint_file)
+ sd_vae.load_vae(model, vae_file)
model.first_stage_model.to(devices.dtype_vae)
if shared.opts.sd_checkpoint_cache > 0:
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
else:
- print(f"Loading weights [{sd_model_hash}] from cache")
- checkpoints_loaded.move_to_end(checkpoint_info)
- model.load_state_dict(checkpoints_loaded[checkpoint_info])
+ vae_name = sd_vae.get_filename(vae_file)
+ print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache")
+ checkpoints_loaded.move_to_end(checkpoint_key)
+ model.load_state_dict(checkpoints_loaded[checkpoint_key])
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
-def load_model(checkpoint_info=None, force=False):
+def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -223,7 +228,7 @@ def load_model(checkpoint_info=None, force=False):
do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model)
- load_model_weights(sd_model, checkpoint_info, force=force)
+ load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
@@ -250,7 +255,7 @@ def reload_model_weights(sd_model, info=None, force=False):
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
- load_model(checkpoint_info, force=force)
+ load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -260,7 +265,7 @@ def reload_model_weights(sd_model, info=None, force=False):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_model_weights(sd_model, checkpoint_info, force=force)
+ load_model_weights(sd_model, checkpoint_info)
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 2ce44d5f..e9239326 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -43,7 +43,7 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
vae_dict.update(res)
return vae_list
-def load_vae(model, checkpoint_file, vae_file="auto"):
+def resolve_vae(checkpoint_file, vae_file="auto"):
global first_load, vae_dict, vae_list
# save_settings = False
@@ -94,6 +94,12 @@ def load_vae(model, checkpoint_file, vae_file="auto"):
if vae_file and not os.path.exists(vae_file):
vae_file = None
+ return vae_file
+
+def load_vae(model, vae_file):
+ global first_load, vae_dict, vae_list
+ # save_settings = False
+
if vae_file:
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
--
cgit v1.2.3
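
Keying the cache on (checkpoint, vae) means the same checkpoint loaded with two VAEs occupies two slots, each evictable LRU-style. A minimal standalone sketch of those semantics:

    from collections import OrderedDict

    cache_limit = 2
    checkpoints_loaded = OrderedDict()

    def remember(key, state_dict):
        checkpoints_loaded[key] = state_dict
        while len(checkpoints_loaded) > cache_limit:
            checkpoints_loaded.popitem(last=False)   # evict least recently used

    remember(("model-a.ckpt", "anime.vae.pt"), {})
    remember(("model-a.ckpt", "mse.vae.pt"), {})     # same model, second slot
    remember(("model-b.ckpt", None), {})             # evicts the oldest entry
    print(list(checkpoints_loaded))
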
From 36966e3200943dbf890b5338cfa939df552d3c47 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Mon, 31 Oct 2022 15:38:58 +0700
Subject: Fix #4035
---
modules/sd_models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f86dc3ed..a29c8c1a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -201,7 +201,7 @@ def load_model_weights(model, checkpoint_info):
if shared.opts.sd_checkpoint_cache > 0:
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1:
checkpoints_loaded.popitem(last=False) # LRU
else:
print(f"Loading weights [{sd_model_hash}] from cache")
--
cgit v1.2.3
From bf7a699845675eefdabb9cfa40c55398976274ae Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Mon, 31 Oct 2022 16:27:27 +0700
Subject: Fix #4035 for real now
---
modules/sd_models.py | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index a29c8c1a..b2dd005a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -165,6 +165,9 @@ def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
+ if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
+ checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
+
if checkpoint_info not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
@@ -198,16 +201,14 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
-
- if shared.opts.sd_checkpoint_cache > 0:
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1:
- checkpoints_loaded.popitem(last=False) # LRU
else:
print(f"Loading weights [{sd_model_hash}] from cache")
- checkpoints_loaded.move_to_end(checkpoint_info)
model.load_state_dict(checkpoints_loaded[checkpoint_info])
+ if shared.opts.sd_checkpoint_cache > 0:
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ checkpoints_loaded.popitem(last=False) # LRU
+
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
--
cgit v1.2.3
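Taken together, the two #4035 fixes make checkpoints_loaded behave as a plain LRU keyed on checkpoint info: the outgoing model's weights are stashed before a switch, and eviction runs on the cache-hit path so the active entry is not counted against the limit twice. A minimal sketch of the OrderedDict-as-LRU idiom the cache relies on (CACHE_SIZE stands in for shared.opts.sd_checkpoint_cache):

    from collections import OrderedDict

    CACHE_SIZE = 2
    checkpoints_loaded = OrderedDict()  # insertion order doubles as recency order

    def remember(info, state_dict):
        checkpoints_loaded[info] = state_dict
        while len(checkpoints_loaded) > CACHE_SIZE:
            checkpoints_loaded.popitem(last=False)  # drop the least recently used

    def fetch(info):
        checkpoints_loaded.move_to_end(info)        # mark as most recently used
        return checkpoints_loaded[info]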
From 006756f9cd6258eae418e9209cfc13f940ec53e1 Mon Sep 17 00:00:00 2001
From: Fampai <>
Date: Mon, 31 Oct 2022 07:26:08 -0400
Subject: Added TI training optimizations
option to use xattention optimizations when training
option to unload vae when training
---
modules/shared.py | 3 ++-
modules/textual_inversion/textual_inversion.py | 9 +++++++++
modules/textual_inversion/ui.py | 7 +++++--
3 files changed, 16 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index fb84afd8..4c3d0ce7 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -256,11 +256,12 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('training', "Training"), {
- "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
+ "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
+ "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 17dfb223..b0a1d26b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -214,6 +214,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
+ unload = shared.opts.unload_models_when_training
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
@@ -238,6 +239,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
+ if unload:
+ shared.sd_model.first_stage_model.to(devices.cpu)
hijack = sd_hijack.model_hijack
@@ -303,6 +306,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
+
+ shared.sd_model.first_stage_model.to(devices.device)
+
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
@@ -330,6 +336,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
processed = processing.process_images(p)
image = processed.images[0]
+ if unload:
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index e712284d..d679e6f4 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -25,8 +25,10 @@ def train_embedding(*args):
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
+ apply_optimizations = shared.opts.training_xattention_optimizations
try:
- sd_hijack.undo_optimizations()
+ if not apply_optimizations:
+ sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
except Exception:
raise
finally:
- sd_hijack.apply_optimizations()
+ if not apply_optimizations:
+ sd_hijack.apply_optimizations()
--
cgit v1.2.3
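The VRAM saving above comes down to shuttling the VAE (first_stage_model) between CPU and GPU around the only points where it is needed: dataset latents are precomputed, so during the optimization loop the VAE sits idle and can live in RAM. A condensed sketch of the pattern, assuming a model object with a first_stage_model submodule:

    import torch

    cpu, gpu = torch.device("cpu"), torch.device("cuda")

    def train(model, steps, preview_every, unload=True):
        if unload:
            model.first_stage_model.to(cpu)       # VAE idle while latents are cached
        for step in range(1, steps + 1):
            ...                                   # optimize against cached latents
            if preview_every and step % preview_every == 0:
                model.first_stage_model.to(gpu)   # decoding a preview needs the VAE
                ...                               # render the preview image
                if unload:
                    model.first_stage_model.to(cpu)
        model.first_stage_model.to(gpu)           # restore before handing back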
From 890e68aaf75ae80d5eb2fa95b4bf1adf78b96881 Mon Sep 17 00:00:00 2001
From: Fampai <>
Date: Mon, 31 Oct 2022 10:07:12 -0400
Subject: Fixed minor bug
when unloading vae during TI training, generating images after
training will error out
---
modules/textual_inversion/textual_inversion.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 54a734f1..0aeb0459 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -409,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ shared.sd_model.first_stage_model.to(devices.device)
return embedding, filename
--
cgit v1.2.3
From 910a097ae2ed78a62101951f1b87137f9e1baaea Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 31 Oct 2022 17:36:45 +0300
Subject: add initial version of the extensions tab; fix broken Restart Gradio
button
---
modules/extensions.py | 83 +++++++++++++++
modules/generation_parameters_copypaste.py | 5 +
modules/scripts.py | 21 +---
modules/shared.py | 10 +-
modules/ui.py | 16 ++-
modules/ui_extensions.py | 162 +++++++++++++++++++++++++++++
6 files changed, 274 insertions(+), 23 deletions(-)
create mode 100644 modules/extensions.py
create mode 100644 modules/ui_extensions.py
(limited to 'modules')
diff --git a/modules/extensions.py b/modules/extensions.py
new file mode 100644
index 00000000..8d6ae848
--- /dev/null
+++ b/modules/extensions.py
@@ -0,0 +1,83 @@
+import os
+import sys
+import traceback
+
+import git
+
+from modules import paths, shared
+
+
+extensions = []
+extensions_dir = os.path.join(paths.script_path, "extensions")
+
+
+def active():
+ return [x for x in extensions if x.enabled]
+
+
+class Extension:
+ def __init__(self, name, path, enabled=True):
+ self.name = name
+ self.path = path
+ self.enabled = enabled
+ self.status = ''
+ self.can_update = False
+
+ repo = None
+ try:
+ if os.path.exists(os.path.join(path, ".git")):
+ repo = git.Repo(path)
+ except Exception:
+ print(f"Error reading github repository info from {path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ if repo is None or repo.bare:
+ self.remote = None
+ else:
+ self.remote = next(repo.remote().urls, None)
+ self.status = 'unknown'
+
+ def list_files(self, subdir, extension):
+ from modules import scripts
+
+ dirpath = os.path.join(self.path, subdir)
+ if not os.path.isdir(dirpath):
+ return []
+
+ res = []
+ for filename in sorted(os.listdir(dirpath)):
+ res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename)))
+
+ res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+
+ return res
+
+ def check_updates(self):
+ repo = git.Repo(self.path)
+ for fetch in repo.remote().fetch("--dry-run"):
+ if fetch.flags != fetch.HEAD_UPTODATE:
+ self.can_update = True
+ self.status = "behind"
+ return
+
+ self.can_update = False
+ self.status = "latest"
+
+ def pull(self):
+ repo = git.Repo(self.path)
+ repo.remotes.origin.pull()
+
+
+def list_extensions():
+ extensions.clear()
+
+ if not os.path.isdir(extensions_dir):
+ return
+
+ for dirname in sorted(os.listdir(extensions_dir)):
+ path = os.path.join(extensions_dir, dirname)
+ if not os.path.isdir(path):
+ continue
+
+ extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
+ extensions.append(extension)
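Everything the Extension class knows about a repository comes from GitPython. The update check boils down to a dry-run fetch and a look at the returned flags; the same logic in isolation (flag handling simplified):

    import git

    def is_behind(path):
        repo = git.Repo(path)
        # Ask the remote what it has, without applying anything locally.
        for fetch_info in repo.remote().fetch(dry_run=True):
            if fetch_info.flags != fetch_info.HEAD_UPTODATE:
                return True   # remote has commits this checkout lacks
        return False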
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index df70c728..985ec95e 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -17,6 +17,11 @@ paste_fields = {}
bind_list = []
+def reset():
+ paste_fields.clear()
+ bind_list.clear()
+
+
def quote(text):
if ',' not in str(text):
return text
diff --git a/modules/scripts.py b/modules/scripts.py
index 96e44bfd..533db45c 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -7,7 +7,7 @@ import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared, paths, script_callbacks
+from modules import shared, paths, script_callbacks, extensions
AlwaysVisible = object()
@@ -107,17 +107,8 @@ def list_scripts(scriptdirname, extension):
for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
- extdir = os.path.join(paths.script_path, "extensions")
- if os.path.exists(extdir):
- for dirname in sorted(os.listdir(extdir)):
- dirpath = os.path.join(extdir, dirname)
- scriptdirpath = os.path.join(dirpath, scriptdirname)
-
- if not os.path.isdir(scriptdirpath):
- continue
-
- for filename in sorted(os.listdir(scriptdirpath)):
- scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
+ for ext in extensions.active():
+ scripts_list += ext.list_files(scriptdirname, extension)
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
@@ -127,11 +118,7 @@ def list_scripts(scriptdirname, extension):
def list_files_with_name(filename):
res = []
- dirs = [paths.script_path]
-
- extdir = os.path.join(paths.script_path, "extensions")
- if os.path.exists(extdir):
- dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
+ dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
for dirpath in dirs:
if not os.path.isdir(dirpath):
diff --git a/modules/shared.py b/modules/shared.py
index e4f163c1..cce87081 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -132,6 +132,7 @@ class State:
current_image = None
current_image_sampling_step = 0
textinfo = None
+ need_restart = False
def skip(self):
self.skipped = True
@@ -354,6 +355,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
+options_templates.update(options_section((None, "Hidden options"), {
+ "disabled_extensions": OptionInfo([], "Disable those extensions"),
+}))
+
+options_templates.update()
+
class Options:
data = None
@@ -365,8 +372,9 @@ class Options:
def __setattr__(self, key, value):
if self.data is not None:
- if key in self.data:
+ if key in self.data or key in self.data_labels:
self.data[key] = value
+ return
return super(Options, self).__setattr__(key, value)
diff --git a/modules/ui.py b/modules/ui.py
index 5055ca64..2c15abb7 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -19,7 +19,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
-from modules import sd_hijack, sd_models, localization, script_callbacks
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions
from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts
@@ -671,6 +671,7 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
+ parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
@@ -1511,8 +1512,9 @@ def create_ui(wrap_gradio_gpu_call):
column = None
with gr.Row(elem_id="settings").style(equal_height=False):
for i, (k, item) in enumerate(opts.data_labels.items()):
+ section_must_be_skipped = item.section[0] is None
- if previous_section != item.section:
+ if previous_section != item.section and not section_must_be_skipped:
if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
if column is not None:
column.__exit__()
@@ -1531,6 +1533,8 @@ def create_ui(wrap_gradio_gpu_call):
if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
+ elif section_must_be_skipped:
+ components.append(dummy_component)
else:
component = create_setting_component(k)
component_dict[k] = component
@@ -1572,9 +1576,10 @@ def create_ui(wrap_gradio_gpu_call):
def request_restart():
shared.state.interrupt()
- settings_interface.gradio_ref.do_restart = True
+ shared.state.need_restart = True
restart_gradio.click(
+
fn=request_restart,
inputs=[],
outputs=[],
@@ -1612,14 +1617,15 @@ def create_ui(wrap_gradio_gpu_call):
interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")]
+ extensions_interface = ui_extensions.create_ui()
+ interfaces += [(extensions_interface, "Extensions", "extensions")]
+
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list:
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
- settings_interface.gradio_ref = demo
-
parameters_copypaste.integrate_settings_paste_fields(component_dict)
parameters_copypaste.run_bind()
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
new file mode 100644
index 00000000..b7d747dc
--- /dev/null
+++ b/modules/ui_extensions.py
@@ -0,0 +1,162 @@
+import json
+import os.path
+import shutil
+import sys
+import time
+import traceback
+
+import git
+
+import gradio as gr
+import html
+
+from modules import extensions, shared, paths
+
+
+def apply_and_restart(disable_list, update_list):
+ disabled = json.loads(disable_list)
+ assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
+
+ update = json.loads(update_list)
+ assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
+
+ update = set(update)
+
+ for ext in extensions.extensions:
+ if ext.name not in update:
+ continue
+
+ try:
+ ext.pull()
+ except Exception:
+ print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ shared.opts.disabled_extensions = disabled
+ shared.opts.save(shared.config_filename)
+
+ shared.state.interrupt()
+ shared.state.need_restart = True
+
+
+def check_updates():
+ for ext in extensions.extensions:
+ if ext.remote is None:
+ continue
+
+ try:
+ ext.check_updates()
+ except Exception:
+ print(f"Error checking updates for {ext.name}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ return extension_table()
+
+
+def extension_table():
+ code = f"""
+
+ """
+
+ return code
+
+
+def install_extension_from_url(dirname, url):
+ assert url, 'No URL specified'
+
+ if dirname is None or dirname == "":
+ *parts, last_part = url.split('/')
+ last_part = last_part.replace(".git", "")
+
+ dirname = last_part
+
+ target_dir = os.path.join(extensions.extensions_dir, dirname)
+ assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
+
+ assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed'
+
+ tmpdir = os.path.join(paths.script_path, "tmp", dirname)
+
+ try:
+ shutil.rmtree(tmpdir, True)
+
+ repo = git.Repo.clone_from(url, tmpdir)
+ repo.remote().fetch()
+
+ os.rename(tmpdir, target_dir)
+
+ extensions.list_extensions()
+ return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
+ finally:
+ shutil.rmtree(tmpdir, True)
+
+
+def create_ui():
+ import modules.ui
+
+ with gr.Blocks(analytics_enabled=False) as ui:
+ with gr.Tabs(elem_id="tabs_extensions") as tabs:
+ with gr.TabItem("Installed"):
+ extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False)
+ extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False)
+
+ with gr.Row():
+ apply = gr.Button(value="Apply and restart UI", variant="primary")
+ check = gr.Button(value="Check for updates")
+
+ extensions_table = gr.HTML(lambda: extension_table())
+
+ apply.click(
+ fn=apply_and_restart,
+ _js="extensions_apply",
+ inputs=[extensions_disabled_list, extensions_update_list],
+ outputs=[],
+ )
+
+ check.click(
+ fn=check_updates,
+ _js="extensions_check",
+ inputs=[],
+ outputs=[extensions_table],
+ )
+
+ with gr.TabItem("Install from URL"):
+ install_url = gr.Text(label="URL for extension's git repository")
+ install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
+ intall_button = gr.Button(value="Install", variant="primary")
+ intall_result = gr.HTML(elem_id="extension_install_result")
+
+ intall_button.click(
+ fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
+ inputs=[install_dirname, install_url],
+ outputs=[extensions_table, intall_result],
+ )
+
+ return ui
--
cgit v1.2.3
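One detail of install_extension_from_url worth calling out: it clones into a scratch directory and only renames into extensions/ on success, so a failed or interrupted clone never leaves a half-installed extension behind. The pattern in isolation (paths illustrative):

    import os
    import shutil
    import git

    def install(url, extensions_dir, dirname):
        target = os.path.join(extensions_dir, dirname)
        assert not os.path.exists(target), f"already exists: {target}"
        tmpdir = os.path.join("tmp", dirname)
        try:
            shutil.rmtree(tmpdir, ignore_errors=True)  # clear stale leftovers
            git.Repo.clone_from(url, tmpdir)           # clone into scratch space
            os.rename(tmpdir, target)                  # promote only on success
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)  # no-op after the rename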
From dc7425a56e7a014cbfa3b3d44ad2321e519fe378 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 31 Oct 2022 18:33:44 +0300
Subject: disable access to extension stuff for non-local servers
---
modules/shared.py | 5 ++++-
modules/ui_extensions.py | 10 ++++++++++
2 files changed, 14 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index cce87081..a27c654e 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -40,7 +40,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
-parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
+parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
@@ -97,6 +97,9 @@ restricted_opts = {
"outdir_save",
}
+if cmd_opts.share or cmd_opts.listen:
+ cmd_opts.disable_extension_access = True
+
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index b7d747dc..e74b7d68 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -13,7 +13,13 @@ import html
from modules import extensions, shared, paths
+def check_access():
+ assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
+
+
def apply_and_restart(disable_list, update_list):
+ check_access()
+
disabled = json.loads(disable_list)
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
@@ -40,6 +46,8 @@ def apply_and_restart(disable_list, update_list):
def check_updates():
+ check_access()
+
for ext in extensions.extensions:
if ext.remote is None:
continue
@@ -89,6 +97,8 @@ def extension_table():
def install_extension_from_url(dirname, url):
+ check_access()
+
assert url, 'No URL specified'
if dirname is None or dirname == "":
--
cgit v1.2.3
From 58cc03edd0fe8c7e64297bcfe51111caaafabfd7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 31 Oct 2022 18:40:47 +0300
Subject: fix scripts I broke with the extension tab changes
---
modules/extensions.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/extensions.py b/modules/extensions.py
index 8d6ae848..897af96e 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -46,7 +46,7 @@ class Extension:
res = []
for filename in sorted(os.listdir(dirpath)):
- res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename)))
+ res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
--
cgit v1.2.3
From 9e22a357545c0395c81dd800c72fa18f350545ec Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 31 Oct 2022 18:45:50 +0300
Subject: fix the error with extension tab not working because of the previous
commit
---
modules/shared.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index a27c654e..c83fb9f5 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -97,8 +97,7 @@ restricted_opts = {
"outdir_save",
}
-if cmd_opts.share or cmd_opts.listen:
- cmd_opts.disable_extension_access = True
+cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
--
cgit v1.2.3
From df6a7ebfe8cc4da23861e3e2583693bb7808d573 Mon Sep 17 00:00:00 2001
From: Roy Shilkrot
Date: Mon, 31 Oct 2022 11:50:33 -0400
Subject: revert things to master
---
modules/api/api.py | 2 --
modules/api/models.py | 6 +-----
2 files changed, 1 insertion(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index c510a833..6a903e4c 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -117,8 +117,6 @@ class Api:
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
- def extrasapi(self):
- raise NotImplementedError
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
reqDict = setUpscalers(req)
diff --git a/modules/api/models.py b/modules/api/models.py
index 035a7179..82ab29b8 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -64,11 +64,7 @@ class PydanticModelGenerator:
self._model_name = model_name
-
- if class_instance is not None:
- self._class_data = merge_class_params(class_instance)
- else:
- self._class_data = {}
+ self._class_data = merge_class_params(class_instance)
self._model_def = [
ModelDef(
--
cgit v1.2.3
From 3f3d14afd5abd07d3843370dc1c28be299dbdbab Mon Sep 17 00:00:00 2001
From: Roy Shilkrot
Date: Mon, 31 Oct 2022 11:51:21 -0400
Subject: nix unused thing
---
modules/api/api.py | 4 ----
1 file changed, 4 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 6a903e4c..536e3f16 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -182,10 +182,6 @@ class Api:
if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found")
- populate = interrogatereq.copy(update={ # Override __init__ params
- }
- )
-
img = self.__base64_to_image(image_b64)
# Override object param
--
cgit v1.2.3
From 8792be50078c2e89ac2fa8ff4e4c1ca82f140d45 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Mon, 31 Oct 2022 17:29:04 -0700
Subject: Add PNG info to pngs only if option is enabled.
---
modules/images.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index a0728553..ae705cbd 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -510,8 +510,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if extension.lower() == '.png':
pnginfo_data = PngImagePlugin.PngInfo()
- for k, v in params.pnginfo.items():
- pnginfo_data.add_text(k, str(v))
+ if opts.enable_pnginfo:
+ for k, v in params.pnginfo.items():
+ pnginfo_data.add_text(k, str(v))
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
--
cgit v1.2.3
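PIL only writes tEXt chunks that are explicitly attached through a PngInfo object, which is why the gating above happens while pnginfo_data is being built rather than at save time. A minimal round trip:

    from PIL import Image, PngImagePlugin

    img = Image.new("RGB", (64, 64))
    info = PngImagePlugin.PngInfo()
    info.add_text("parameters", "example prompt, Steps: 20")  # generation infotext
    img.save("out.png", pnginfo=info)

    reread = Image.open("out.png")
    print(reread.text["parameters"])   # -> "example prompt, Steps: 20"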
From 525c1edf431d9c9efc1349be8657f0301299e3bc Mon Sep 17 00:00:00 2001
From: k_sugawara
Date: Tue, 1 Nov 2022 09:40:54 +0900
Subject: make save dir if it does not exist
---
modules/img2img.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/img2img.py b/modules/img2img.py
index efda26e1..35c5df9b 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -55,6 +55,7 @@ def process_batch(p, input_dir, output_dir, args):
filename = f"{left}-{n}{right}"
if not save_normally:
+ os.makedirs(output_dir, exist_ok=True)
processed_image.save(os.path.join(output_dir, filename))
--
cgit v1.2.3
From 5b0f624bdc1335313258f59a37607e699e800c22 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 1 Nov 2022 09:59:00 +0300
Subject: Added Available tab to extensions UI.
---
modules/ui_extensions.py | 112 +++++++++++++++++++++++++++++++++++++++++++----
1 file changed, 104 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index e74b7d68..ab807722 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -13,6 +13,9 @@ import html
from modules import extensions, shared, paths
+available_extensions = {"extensions": []}
+
+
def check_access():
assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
@@ -96,6 +99,14 @@ def extension_table():
return code
+def normalize_git_url(url):
+ if url is None:
+ return ""
+
+ url = url.replace(".git", "")
+ return url
+
+
def install_extension_from_url(dirname, url):
check_access()
@@ -103,14 +114,15 @@ def install_extension_from_url(dirname, url):
if dirname is None or dirname == "":
*parts, last_part = url.split('/')
- last_part = last_part.replace(".git", "")
+ last_part = normalize_git_url(last_part)
dirname = last_part
target_dir = os.path.join(extensions.extensions_dir, dirname)
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
- assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed'
+ normalized_url = normalize_git_url(url)
+ assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
tmpdir = os.path.join(paths.script_path, "tmp", dirname)
@@ -128,18 +140,80 @@ def install_extension_from_url(dirname, url):
shutil.rmtree(tmpdir, True)
+def install_extension_from_index(url):
+ ext_table, message = install_extension_from_url(None, url)
+
+ return refresh_available_extensions_from_data(), ext_table, message
+
+
+def refresh_available_extensions(url):
+ global available_extensions
+
+ import urllib.request
+ with urllib.request.urlopen(url) as response:
+ text = response.read()
+
+ available_extensions = json.loads(text)
+
+ return url, refresh_available_extensions_from_data(), ''
+
+
+def refresh_available_extensions_from_data():
+ extlist = available_extensions["extensions"]
+ installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
+
+ code = f"""
+
+ """
+
+ return code
+
+
def create_ui():
import modules.ui
with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tabs(elem_id="tabs_extensions") as tabs:
with gr.TabItem("Installed"):
- extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False)
- extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False)
with gr.Row():
apply = gr.Button(value="Apply and restart UI", variant="primary")
check = gr.Button(value="Check for updates")
+ extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
+ extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
extensions_table = gr.HTML(lambda: extension_table())
@@ -157,16 +231,38 @@ def create_ui():
outputs=[extensions_table],
)
+ with gr.TabItem("Available"):
+ with gr.Row():
+ refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
+ available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
+ extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
+ install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
+
+ install_result = gr.HTML()
+ available_extensions_table = gr.HTML()
+
+ refresh_available_extensions_button.click(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
+ inputs=[available_extensions_index],
+ outputs=[available_extensions_index, available_extensions_table, install_result],
+ )
+
+ install_extension_button.click(
+ fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
+ inputs=[extension_to_install],
+ outputs=[available_extensions_table, extensions_table, install_result],
+ )
+
with gr.TabItem("Install from URL"):
install_url = gr.Text(label="URL for extension's git repository")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
- intall_button = gr.Button(value="Install", variant="primary")
- intall_result = gr.HTML(elem_id="extension_install_result")
+ install_button = gr.Button(value="Install", variant="primary")
+ install_result = gr.HTML(elem_id="extension_install_result")
- intall_button.click(
+ install_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
inputs=[install_dirname, install_url],
- outputs=[extensions_table, intall_result],
+ outputs=[extensions_table, install_result],
)
return ui
--
cgit v1.2.3
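Duplicate detection above now compares normalized remotes, so https://host/repo and https://host/repo.git count as the same extension whether the URL comes from the index or from a local checkout. The comparison on its own (note that str.replace strips ".git" anywhere in the string, not just as a suffix):

    def normalize_git_url(url):
        if url is None:
            return ""
        return url.replace(".git", "")

    def already_installed(url, installed_remotes):
        target = normalize_git_url(url)
        return any(normalize_git_url(r) == target for r in installed_remotes)

    print(already_installed("https://example.com/ext.git",
                            ["https://example.com/ext"]))   # True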
From af758e97fa2c4c853042f121af4e974be01e6696 Mon Sep 17 00:00:00 2001
From: Jairo Correa
Date: Tue, 1 Nov 2022 04:01:49 -0300
Subject: Unload sd_model before loading the other
---
modules/lowvram.py | 21 +++++++++++++--------
modules/processing.py | 3 +++
modules/sd_hijack.py | 4 ++++
modules/sd_models.py | 14 +++++++++++++-
4 files changed, 33 insertions(+), 9 deletions(-)
(limited to 'modules')
diff --git a/modules/lowvram.py b/modules/lowvram.py
index f327c3df..a4652cb1 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
- def first_stage_model_encode_wrap(self, encoder, x):
- send_me_to_gpu(self, None)
- return encoder(x)
- def first_stage_model_decode_wrap(self, decoder, z):
- send_me_to_gpu(self, None)
- return decoder(z)
+ first_stage_model = sd_model.first_stage_model
+ first_stage_model_encode = sd_model.first_stage_model.encode
+ first_stage_model_decode = sd_model.first_stage_model.decode
+
+ def first_stage_model_encode_wrap(x):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_encode(x)
+
+ def first_stage_model_decode_wrap(z):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_decode(z)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
# register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
- sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
- sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+ sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+ sd_model.first_stage_model.decode = first_stage_model_decode_wrap
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram:
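The lowvram rewrite above swaps lambdas that looked up sd_model.first_stage_model at call time for closures over the original bound methods, captured once during setup. That keeps the wrappers pointing at the pristine implementations no matter what the attributes are later rebound to. The idiom in isolation:

    class Codec:
        def encode(self, x):
            return [ord(c) for c in x]

    codec = Codec()
    original_encode = codec.encode        # bound method captured up front

    def encode_wrap(x):
        print("moving weights to GPU")    # stand-in for send_me_to_gpu(...)
        return original_encode(x)         # always the untouched original

    codec.encode = encode_wrap
    print(codec.encode("hi"))             # moving weights to GPU / [104, 105]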
diff --git a/modules/processing.py b/modules/processing.py
index b1df4918..57d3a523 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.postprocess(p, res)
+ p.sd_model = None
+ p.sampler = None
+
return res
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 0f10828e..bc49d235 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+ self.layers = None
+ self.circular_enabled = False
+ self.clip = None
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f86dc3ed..90007da3 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,6 +1,7 @@
import collections
import os.path
import sys
+import gc
from collections import namedtuple
import torch
import re
@@ -220,6 +221,12 @@ def load_model(checkpoint_info=None):
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
+ if shared.sd_model:
+ sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+ shared.sd_model = None
+ gc.collect()
+ devices.torch_gc()
+
sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
@@ -233,6 +240,7 @@ def load_model(checkpoint_info=None):
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack()
+
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
@@ -252,14 +260,18 @@ def load_model(checkpoint_info=None):
return sd_model
-def reload_model_weights(sd_model, info=None):
+def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
+ if not sd_model:
+ sd_model = shared.sd_model
+
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+ del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info)
return shared.sd_model
--
cgit v1.2.3
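The sd_models change is the headline here: before building the new checkpoint, every strong reference to the old one is dropped and collection is forced, so peak memory holds one model instead of two during a switch. Condensed sketch (shared stands in for the real module):

    import gc
    import torch

    def swap_model(shared, build_new_model):
        if shared.sd_model is not None:
            shared.sd_model = None           # drop the last strong reference
            gc.collect()                     # reclaim the Python-side graph
            if torch.cuda.is_available():
                torch.cuda.empty_cache()     # hand freed VRAM back to the allocator
        shared.sd_model = build_new_model()
        return shared.sd_model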
From d35bf649456da2558cbb6f2ea16fa1606022b7e7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 1 Nov 2022 14:19:24 +0300
Subject: make launch.py run installers for extensions that have ones; add some
more classes to safety module for an extension
---
modules/safe.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/safe.py b/modules/safe.py
index 399165a1..348a24fc 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -32,7 +32,7 @@ class RestrictedUnpickler(pickle.Unpickler):
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
return getattr(torch._utils, name)
- if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name)
--
cgit v1.2.3
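ByteStorage simply joins the allowlist; the enforcement mechanism is pickle's find_class hook, which fires for every global the stream tries to import. A stripped-down sketch of the same shape:

    import io
    import pickle
    from collections import OrderedDict

    ALLOWED = {("collections", "OrderedDict")}

    class RestrictedUnpickler(pickle.Unpickler):
        def find_class(self, module, name):
            if (module, name) in ALLOWED:
                return super().find_class(module, name)
            raise pickle.UnpicklingError(f"global {module}.{name} is forbidden")

    data = pickle.dumps(OrderedDict(a=1))
    print(RestrictedUnpickler(io.BytesIO(data)).load())   # OrderedDict([('a', 1)])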
From 467cae167a3066ffa2b2a5e6f16dd42642219aba Mon Sep 17 00:00:00 2001
From: TinkTheBoush
Date: Tue, 1 Nov 2022 23:29:12 +0900
Subject: append_tag_shuffle
---
modules/hypernetworks/hypernetwork.py | 4 ++--
modules/textual_inversion/dataset.py | 10 ++++++++--
modules/textual_inversion/textual_inversion.py | 4 ++--
modules/ui.py | 3 +++
4 files changed, 15 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index a11e01d6..7630fb81 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -331,7 +331,7 @@ def report_statistics(loss_info:dict):
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@@ -376,7 +376,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index ad726577..e9d97cc1 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -24,7 +24,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", shuffle_tags=True, model=None, device=None, template_file=None, include_cond=False, batch_size=1):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
@@ -33,6 +33,7 @@ class PersonalizedBase(Dataset):
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+ self.shuffle_tags = shuffle_tags
self.dataset = []
@@ -98,7 +99,12 @@ class PersonalizedBase(Dataset):
def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", filename_text)
+ if self.shuffle_tags:
+ tags = filename_text.split(',')
+ random.shuffle(tags)
+ text = text.replace("[filewords]", ','.join(tags))
+ else:
+ text = text.replace("[filewords]", filename_text)
return text
def __len__(self):
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index e0babb46..64700e23 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -224,7 +224,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
@@ -271,7 +271,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
diff --git a/modules/ui.py b/modules/ui.py
index 2c15abb7..ad383979 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1267,6 +1267,7 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
+ shuffle_tags = gr.Checkbox(label='Shuffle tags by "," when creating texts', value=True)
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
@@ -1361,6 +1362,7 @@ def create_ui(wrap_gradio_gpu_call):
template_file,
save_image_with_stored_embedding,
preview_from_txt2img,
+ shuffle_tags,
*txt2img_preview_params,
],
outputs=[
@@ -1385,6 +1387,7 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every,
template_file,
preview_from_txt2img,
+ shuffle_tags,
*txt2img_preview_params,
],
outputs=[
--
cgit v1.2.3
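The shuffle itself is a comma split, an in-place random.shuffle, and a re-join, applied fresh each time a training text is built so the embedding cannot latch onto tag order. The substitution step on its own (using the corrected attribute name, shuffle_tags):

    import random

    def fill_filewords(template, filename_text, shuffle_tags=True):
        if shuffle_tags:
            tags = filename_text.split(',')
            random.shuffle(tags)                           # new order every call
            return template.replace("[filewords]", ','.join(tags))
        return template.replace("[filewords]", filename_text)

    print(fill_filewords("a photo of [filewords]", "1girl, smile, outdoors"))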
From bc607686065b8c7751d1af7c05b960378fa256de Mon Sep 17 00:00:00 2001
From: Billy Cao
Date: Tue, 1 Nov 2022 23:26:55 +0800
Subject: Enable override_settings to take effect for hypernetworks
---
modules/processing.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 57d3a523..86d015af 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -422,13 +422,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
- opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible
+ opts.data[k] = v # we don't call onchange for simplicity which makes changing model impossible
+ if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not
res = process_images_inner(p)
- finally:
+ finally: # restore opts to original state
for k, v in stored_opts.items():
opts.data[k] = v
+ if k == 'sd_hypernetwork': shared.reload_hypernetworks()
return res
--
cgit v1.2.3
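Because opts.data is patched directly, the apply path and the restore path both need the same onchange side effect for the hypernetwork key. The save/patch/restore shape, sketched with a try/finally and a plain dict standing in for opts.data:

    def with_overrides(opts_data, overrides, reload_hypernetworks, work):
        stored = {k: opts_data[k] for k in overrides}   # snapshot originals
        try:
            for k, v in overrides.items():
                opts_data[k] = v
                if k == 'sd_hypernetwork':
                    reload_hypernetworks()              # cheap enough to do eagerly
            return work()
        finally:
            for k, v in stored.items():                 # always restore
                opts_data[k] = v
                if k == 'sd_hypernetwork':
                    reload_hypernetworks()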
From 198a1ffcfc963a3d74674fad560e87dbebf7949f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 1 Nov 2022 19:13:59 +0300
Subject: fix API returning extra stuff in base64 encoded images for #3972
---
modules/api/api.py | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index e702c9c0..bb87d795 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,6 +1,8 @@
+import base64
+import io
import time
import uvicorn
-from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
+from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
from modules import devices
@@ -29,6 +31,12 @@ def setUpscalers(req: dict):
return reqDict
+def encode_pil_to_base64(image):
+ buffer = io.BytesIO()
+ image.save(buffer, format="png")
+ return base64.b64encode(buffer.getvalue())
+
+
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
--
cgit v1.2.3
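The replacement encoder returns bare base64 over raw PNG bytes, without the data:image/png;base64, prefix Gradio's helper adds (the "extra stuff" of the subject line). Decoding is then symmetric:

    import base64
    import io
    from PIL import Image

    def encode_pil_to_base64(image):
        buffer = io.BytesIO()
        image.save(buffer, format="png")
        return base64.b64encode(buffer.getvalue())

    def decode_base64_to_pil(b64):
        return Image.open(io.BytesIO(base64.b64decode(b64)))

    img = Image.new("RGB", (8, 8), "red")
    assert decode_base64_to_pil(encode_pil_to_base64(img)).size == (8, 8)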
From 401350cd59555439570ba5bc95f0ac5698e372e4 Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Tue, 1 Nov 2022 14:03:56 -0500
Subject: clear on the client-side again
---
modules/ui.py | 9 +--------
1 file changed, 1 insertion(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 447722cd..f43e79ab 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -400,19 +400,12 @@ def create_seed_inputs():
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
-def clear_prompt(prompt, _prompt_neg, confirmed, _token_counter):
- """Given confirmation from a user on the client-side, go ahead with clearing prompt"""
- if confirmed:
- return ["", "", confirmed, update_token_counter("", 1)]
- else:
- return [prompt, _prompt_neg, confirmed, _token_counter]
-
def connect_clear_prompt(button, prompt, prompt_neg, _dummy_confirmed, token_counter):
"""Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
button.click(
_js="clear_prompt",
- fn=clear_prompt,
+ fn=None,
inputs=[prompt, prompt_neg, _dummy_confirmed, token_counter],
outputs=[prompt, prompt_neg, _dummy_confirmed, token_counter],
)
--
cgit v1.2.3
From 1dd5d6bafad7575f347056a29636cbab71c1c468 Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Tue, 1 Nov 2022 14:33:55 -0500
Subject: clean py func defs
---
modules/ui.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index f43e79ab..8a1f3887 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -401,13 +401,13 @@ def create_seed_inputs():
-def connect_clear_prompt(button, prompt, prompt_neg, _dummy_confirmed, token_counter):
+def connect_clear_prompt(button):
"""Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
button.click(
_js="clear_prompt",
fn=None,
- inputs=[prompt, prompt_neg, _dummy_confirmed, token_counter],
- outputs=[prompt, prompt_neg, _dummy_confirmed, token_counter],
+ inputs=[],
+ outputs=[],
)
@@ -746,7 +746,7 @@ def create_ui(wrap_gradio_gpu_call):
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
- connect_clear_prompt(clear_prompt_button, txt2img_prompt, txt2img_negative_prompt, dummy_component, token_counter)
+ connect_clear_prompt(clear_prompt_button)
txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
@@ -929,7 +929,7 @@ def create_ui(wrap_gradio_gpu_call):
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
- connect_clear_prompt(clear_prompt_button, img2img_prompt, img2img_negative_prompt, dummy_component, token_counter)
+ connect_clear_prompt(clear_prompt_button)
img2img_prompt_img.change(
fn=modules.images.image_data,
--
cgit v1.2.3
From 86d35526a13a0e2432ab71d1d40b191615d3e343 Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Tue, 1 Nov 2022 14:53:40 -0500
Subject: make line evil again
---
modules/ui.py | 8 ++------
1 file changed, 2 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 8a1f3887..bd67c1bd 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -692,9 +692,7 @@ def create_ui(wrap_gradio_gpu_call):
parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\
- txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\
- token_button, clear_prompt_button = create_toprow(is_img2img=False)
+ txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, clear_prompt_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
@@ -850,9 +848,7 @@ def create_ui(wrap_gradio_gpu_call):
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit,\
- img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\
- token_counter, token_button, clear_prompt_button = create_toprow(is_img2img=True)
+ img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button, clear_prompt_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'):
--
cgit v1.2.3
From cffc240a7327ae60671ff533469fc4ed4bf605de Mon Sep 17 00:00:00 2001
From: Nerogar
Date: Sun, 23 Oct 2022 14:05:25 +0200
Subject: fixed textual inversion training with inpainting models
---
modules/textual_inversion/textual_inversion.py | 27 +++++++++++++++++++++++++-
1 file changed, 26 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 0aeb0459..2630c7c9 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -224,6 +224,26 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
+def create_dummy_mask(x, width=None, height=None):
+ if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = shared.sd_model.get_first_stage_encoding(shared.sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ else:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since it's not going to be used beyond its batch size.
+ image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ return image_conditioning
+
+
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
@@ -286,6 +306,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
forced_filename = ""
embedding_yet_to_be_embedded = False
+ img_c = None
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar:
embedding.step = i + ititial_step
@@ -299,8 +320,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
with torch.autocast("cuda"):
c = cond_model([entry.cond_text for entry in entries])
+ if img_c is None:
+ img_c = create_dummy_mask(c, training_width, training_height)
+
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ loss = shared.sd_model(x, cond)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
--
cgit v1.2.3
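The only non-obvious call above is the pad. For a 4-D tensor, F.pad reads its six-element tuple pairwise from the last dimension inward, so (0, 0, 0, 0, 1, 0) leaves width and height alone and prepends exactly one channel filled with 1.0: the all-ones "everything is masked" plane the inpainting model expects. Verifiable in isolation:

    import torch
    import torch.nn.functional as F

    cond = torch.zeros(2, 4, 8, 8)                       # (batch, channels, h, w)
    padded = F.pad(cond, (0, 0, 0, 0, 1, 0), value=1.0)  # one new leading channel
    print(padded.shape)           # torch.Size([2, 5, 8, 8])
    print(padded[:, 0].unique())  # tensor([1.]) -- the inserted mask plane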
From cd88e21dc5d5cfdfbd408454acd259b7db9d0ec8 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 2 Nov 2022 00:34:58 +0000
Subject: Fix class name typo and add descriptions to fields.
---
modules/script_callbacks.py | 23 ++++++++++++++++-------
1 file changed, 16 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index b0b8dc47..ff40b056 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -24,13 +24,22 @@ class ImageSaveParams:
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
-class CGFDenoiserParams:
- def __init__(self, x_in, image_cond_in, sigma_in, sampling_step, total_sampling_steps):
- self.x_in = x_in
- self.image_cond_in = image_cond_in
- self.sigma_in = sigma_in
+class CFGDenoiserParams:
+ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.image_cond = image_cond
+ """Conditioning image"""
+
+ self.sigma = sigma
+ """Current sigma noise step value"""
+
self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
@@ -94,7 +103,7 @@ def image_saved_callback(params: ImageSaveParams):
report_exception(c, 'image_saved_callback')
-def cfg_denoiser_callback(params: CGFDenoiserParams):
+def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callbacks_cfg_denoiser:
try:
c.callback(params)
@@ -153,7 +162,7 @@ def on_image_saved(callback):
def on_cfg_denoiser(callback):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument:
- - params: CGFDenoiserParams - parameters to be passed to the inner model and sampling state details.
+ - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
add_callback(callbacks_cfg_denoiser, callback)
--
cgit v1.2.3
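
Note: with the class renamed to CFGDenoiserParams and its fields documented, an extension can observe sampler state on every denoiser call. A minimal sketch, assuming it runs inside the webui so modules.script_callbacks is importable:

from modules import script_callbacks

def log_denoiser_step(params):
    # params is the CFGDenoiserParams instance described above
    print(f"step {params.sampling_step}/{params.total_sampling_steps}, "
          f"sigma {params.sigma[0].item():.4f}, latent {tuple(params.x.shape)}")

script_callbacks.on_cfg_denoiser(log_denoiser_step)
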
From 5b6bedf6f2ebacb7f1f5809af8e26a6a1af16e2a Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 2 Nov 2022 00:38:17 +0000
Subject: Update class name and assign back to vars
---
modules/sd_samplers.py | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 30cb5c4b..ebc0d896 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -11,7 +11,7 @@ from modules import prompt_parser, devices, processing, images
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
-from modules.script_callbacks import CGFDenoiserParams, cfg_denoiser_callback
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -279,7 +279,11 @@ class CFGDenoiser(torch.nn.Module):
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
- cfg_denoiser_callback(CGFDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps))
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
+ cfg_denoiser_callback(denoiser_params)
+ x_in = denoiser_params.x
+ image_cond_in = denoiser_params.image_cond
+ sigma_in = denoiser_params.sigma
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
--
cgit v1.2.3
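
Note: assigning the fields back is what turns the callback from read-only into read-write; whatever a callback stores in params.x, params.image_cond or params.sigma is now used for the rest of the sampling step. A hedged sketch of a mutating callback (purely illustrative; perturbing latents like this degrades output):

import torch
from modules import script_callbacks

def nudge_latents(params):
    # effective only because sd_samplers reads params.x back afterwards
    params.x = params.x + torch.randn_like(params.x) * 0.001

script_callbacks.on_cfg_denoiser(nudge_latents)
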
From c9148b2312b36fee8727f5233da9dbe32aa1f58c Mon Sep 17 00:00:00 2001
From: Jairo Correa
Date: Tue, 1 Nov 2022 21:56:47 -0300
Subject: Release processing resources after it finishes
---
modules/img2img.py | 2 ++
modules/processing.py | 7 ++++---
modules/txt2img.py | 2 ++
3 files changed, 8 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/img2img.py b/modules/img2img.py
index 35c5df9b..fac010aa 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -137,6 +137,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if processed is None:
processed = process_images(p)
+ p.close()
+
shared.total_tqdm.clear()
generation_info_js = processed.js()
diff --git a/modules/processing.py b/modules/processing.py
index 57d3a523..b541ee2b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -202,6 +202,10 @@ class StableDiffusionProcessing():
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
raise NotImplementedError()
+ def close(self):
+ self.sd_model = None
+ self.sampler = None
+
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
@@ -597,9 +601,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.postprocess(p, res)
- p.sd_model = None
- p.sampler = None
-
return res
diff --git a/modules/txt2img.py b/modules/txt2img.py
index c9d5a090..8e4e8677 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -47,6 +47,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if processed is None:
processed = process_images(p)
+ p.close()
+
shared.total_tqdm.clear()
generation_info_js = processed.js()
--
cgit v1.2.3
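
Note: moving the teardown behind an explicit close() lets callers and postprocess hooks keep using p.sd_model until they are done, while still dropping the references so the model and sampler can be garbage-collected. A sketch of the intended call pattern; the try/finally is a defensive assumption here, not something this patch adds:

# p: a StableDiffusionProcessing built by txt2img/img2img (assumed in scope)
try:
    processed = process_images(p)
finally:
    # drops p.sd_model and p.sampler so the large objects they reference
    # stop being reachable through the processing object
    p.close()
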
From 5510c282b1f1974005790066b5e444f74a5178fb Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 2 Nov 2022 07:26:31 +0300
Subject: fix for extensions' javascript not loading
---
modules/ui.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 2c15abb7..a94f46ea 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -671,6 +671,8 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
+ reload_javascript()
+
parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
@@ -1782,4 +1784,3 @@ def load_javascript(raw_response):
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
-reload_javascript()
--
cgit v1.2.3
From 056f06d3738c267b1014e6e8e1ef5bd97af1fb45 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Wed, 2 Nov 2022 12:51:46 +0700
Subject: Reload VAE without reloading sd checkpoint
---
modules/sd_models.py | 15 ++++----
modules/sd_vae.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++++----
2 files changed, 97 insertions(+), 15 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6ab85b65..883639d1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -159,15 +159,13 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
-vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
-
def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
- checkpoint_key = (checkpoint_info, vae_file)
+ checkpoint_key = checkpoint_info
if checkpoint_key not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
@@ -190,13 +188,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
- sd_vae.load_vae(model, vae_file)
- model.first_stage_model.to(devices.dtype_vae)
-
if shared.opts.sd_checkpoint_cache > 0:
+ # if PR #4035 were to get merged, restore base VAE first before caching
checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
+
else:
vae_name = sd_vae.get_filename(vae_file)
print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache")
@@ -207,6 +204,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
+ sd_vae.load_vae(model, vae_file)
+
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
@@ -254,14 +253,14 @@ def load_model(checkpoint_info=None):
return sd_model
-def reload_model_weights(sd_model=None, info=None, force=False):
+def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
if not sd_model:
sd_model = shared.sd_model
- if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force:
+ if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index e9239326..78e14e8a 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,26 +1,65 @@
import torch
import os
from collections import namedtuple
-from modules import shared, devices
+from modules import shared, devices, script_callbacks
from modules.paths import models_path
import glob
+
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
vae_dir = "VAE"
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
+
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+
+
default_vae_dict = {"auto": "auto", "None": "None"}
default_vae_list = ["auto", "None"]
+
+
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
vae_dict = dict(default_vae_dict)
vae_list = list(default_vae_list)
first_load = True
+
+base_vae = None
+loaded_vae_file = None
+checkpoint_info = None
+
+
+def get_base_vae(model):
+ if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
+ return base_vae
+ return None
+
+
+def store_base_vae(model):
+ global base_vae, checkpoint_info
+ if checkpoint_info != model.sd_checkpoint_info:
+ base_vae = model.first_stage_model.state_dict().copy()
+ checkpoint_info = model.sd_checkpoint_info
+
+
+def delete_base_vae():
+ global base_vae, checkpoint_info
+ base_vae = None
+ checkpoint_info = None
+
+
+def restore_base_vae(model):
+ global base_vae, checkpoint_info
+ if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
+ load_vae_dict(model, base_vae)
+ delete_base_vae()
+
+
def get_filename(filepath):
return os.path.splitext(os.path.basename(filepath))[0]
+
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
global vae_dict, vae_list
res = {}
@@ -43,6 +82,7 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
vae_dict.update(res)
return vae_list
+
def resolve_vae(checkpoint_file, vae_file="auto"):
global first_load, vae_dict, vae_list
# save_settings = False
@@ -96,24 +136,26 @@ def resolve_vae(checkpoint_file, vae_file="auto"):
return vae_file
-def load_vae(model, vae_file):
- global first_load, vae_dict, vae_list
+
+def load_vae(model, vae_file=None):
+ global first_load, vae_dict, vae_list, loaded_vae_file
# save_settings = False
if vae_file:
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
- model.first_stage_model.load_state_dict(vae_dict_1)
+ load_vae_dict(model, vae_dict_1)
- # If vae used is not in dict, update it
- # It will be removed on refresh though
- if vae_file is not None:
+ # If vae used is not in dict, update it
+ # It will be removed on refresh though
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
vae_list.append(vae_opt)
+ loaded_vae_file = vae_file
+
"""
# Save current VAE to VAE settings, maybe? will it work?
if save_settings:
@@ -124,4 +166,45 @@ def load_vae(model, vae_file):
"""
first_load = False
+
+
+# don't call this from outside
+def load_vae_dict(model, vae_dict_1=None):
+ if vae_dict_1:
+ store_base_vae(model)
+ model.first_stage_model.load_state_dict(vae_dict_1)
+ else:
+ restore_base_vae()
model.first_stage_model.to(devices.dtype_vae)
+
+
+def reload_vae_weights(sd_model=None, vae_file="auto"):
+ from modules import lowvram, devices, sd_hijack
+
+ if not sd_model:
+ sd_model = shared.sd_model
+
+ checkpoint_info = sd_model.sd_checkpoint_info
+ checkpoint_file = checkpoint_info.filename
+ vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
+
+ if loaded_vae_file == vae_file:
+ return
+
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+ else:
+ sd_model.to(devices.cpu)
+
+ sd_hijack.model_hijack.undo_hijack(sd_model)
+
+ load_vae(sd_model, vae_file)
+
+ sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
+
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ sd_model.to(devices.device)
+
+ print(f"VAE Weights loaded.")
+ return sd_model
--
cgit v1.2.3
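
Note: the new module-level trio (base_vae, loaded_vae_file, checkpoint_info) acts as a one-slot undo buffer: before foreign VAE weights are loaded, the checkpoint's own first-stage weights are copied aside so they can be restored later without reloading the whole checkpoint. A condensed restatement of that scheme:

# simplified sketch of the store/restore logic in sd_vae.py
base_vae = None          # state_dict of the checkpoint's original VAE
checkpoint_info = None   # checkpoint that state_dict belongs to

def store(model):
    global base_vae, checkpoint_info
    if checkpoint_info != model.sd_checkpoint_info:
        base_vae = model.first_stage_model.state_dict().copy()
        checkpoint_info = model.sd_checkpoint_info

def restore(model):
    global base_vae, checkpoint_info
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
        model.first_stage_model.load_state_dict(base_vae)
    base_vae = None
    checkpoint_info = None
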
From 95c6308ccd2e075d1fb804f5b98a4f0b07b87b7d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 2 Nov 2022 09:47:53 +0300
Subject: switch to gradio 3.8
---
modules/ui.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index a94f46ea..45cd8c3f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1572,8 +1572,7 @@ def create_ui(wrap_gradio_gpu_call):
reload_script_bodies.click(
fn=reload_scripts,
inputs=[],
- outputs=[],
- _js='function(){}'
+ outputs=[]
)
def request_restart():
@@ -1585,7 +1584,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=request_restart,
inputs=[],
outputs=[],
- _js='function(){restart_reload()}'
+ _js='restart_reload'
)
if column is not None:
--
cgit v1.2.3
From dd2108fdac2ebf943d4ac3563a49202222b88acf Mon Sep 17 00:00:00 2001
From: Maiko Tan
Date: Wed, 2 Nov 2022 15:04:35 +0800
Subject: fix: should invoke callback as well in api only mode
---
modules/script_callbacks.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index da88635b..c28e220e 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -2,6 +2,7 @@ import sys
import traceback
from collections import namedtuple
import inspect
+from typing import Optional
from fastapi import FastAPI
from gradio import Blocks
@@ -62,7 +63,7 @@ def clear_callbacks():
callbacks_image_saved.clear()
callbacks_cfg_denoiser.clear()
-def app_started_callback(demo: Blocks, app: FastAPI):
+def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callbacks_app_started:
try:
c.callback(demo, app)
--
cgit v1.2.3
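
Note: typing demo as Optional[Blocks] documents the API-only launch path, where no gradio UI is constructed and the callback receives None; extensions should therefore hang new routes off the FastAPI app only. A minimal sketch, assuming a webui extension context:

from modules import script_callbacks

def add_route(demo, app):
    # demo is None when the server runs in API-only mode
    @app.get("/my-extension/ping")
    def ping():
        return {"ok": True, "has_ui": demo is not None}

script_callbacks.on_app_started(add_route)
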
From a5409a6e4bc3eaa9757a7505d4564ad8e0d899ea Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Wed, 2 Nov 2022 14:37:22 +0700
Subject: Save VAE provided by cmd_opts.vae_path
---
modules/sd_vae.py | 11 ++++-------
1 file changed, 4 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 78e14e8a..71e7a6e6 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -78,27 +78,24 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
vae_list.extend(default_vae_list)
vae_list.extend(list(res.keys()))
vae_dict.clear()
- vae_dict.update(default_vae_dict)
vae_dict.update(res)
+ vae_dict.update(default_vae_dict)
return vae_list
def resolve_vae(checkpoint_file, vae_file="auto"):
global first_load, vae_dict, vae_list
- # save_settings = False
- # if vae_file argument is provided, it takes priority
+ # if vae_file argument is provided, it takes priority, but is not saved
if vae_file and vae_file not in default_vae_list:
if not os.path.isfile(vae_file):
vae_file = "auto"
- # save_settings = True
print("VAE provided as function argument doesn't exist")
- # for the first load, if vae-path is provided, it takes priority and failure is reported
+ # for the first load, if vae-path is provided, it takes priority, is saved, and failure is reported
if first_load and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
- # save_settings = True
- # print("Using VAE provided as command line argument")
+ shared.opts.data['sd_vae'] = get_filename(vae_file)
else:
print("VAE provided as command line argument doesn't exist")
# else, we load from settings
--
cgit v1.2.3
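
Note: swapping the two update() calls changes precedence, since with dict.update the later call wins on key collisions; applying default_vae_dict last guarantees the built-in "auto" and "None" entries can no longer be shadowed by a discovered file with the same stem. A short illustration of that semantics:

defaults = {"auto": "auto", "None": "None"}
found = {"auto": "/models/VAE/auto.pt"}  # hypothetical colliding filename

d = {}
d.update(found)
d.update(defaults)  # applied last, so it wins the collision
assert d["auto"] == "auto"
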
From 4a8cf01f6f7f072cc9c67d6b31662384b212dd9c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 2 Nov 2022 12:12:32 +0300
Subject: remove duplicate code from #3970
---
modules/api/api.py | 10 +---------
modules/shared.py | 14 ++++++++++++++
modules/ui.py | 10 +---------
3 files changed, 16 insertions(+), 18 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index b3d85e46..71c9c160 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -178,15 +178,7 @@ class Api:
progress = min(progress, 1)
- # copy from check_progress_call of ui.py
-
- if shared.parallel_processing_allowed:
- if shared.state.sampling_step - shared.state.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.state.current_latent is not None:
- if shared.opts.show_progress_grid:
- shared.state.current_image = samples_to_image_grid(shared.state.current_latent)
- else:
- shared.state.current_image = sample_to_image(shared.state.current_latent)
- shared.state.current_image_sampling_step = shared.state.sampling_step
+ shared.state.set_current_image()
current_image = None
if shared.state.current_image and not req.skip_current_image:
diff --git a/modules/shared.py b/modules/shared.py
index 04aaa648..e65f6080 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -184,6 +184,20 @@ class State:
devices.torch_gc()
+ """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
+ def set_current_image(self):
+ if not parallel_processing_allowed:
+ return
+
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None:
+ if opts.show_progress_grid:
+ self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
+ else:
+ self.current_image = sd_samplers.sample_to_image(self.current_latent)
+
+ self.current_image_sampling_step = self.sampling_step
+
+
state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
diff --git a/modules/ui.py b/modules/ui.py
index 45cd8c3f..784439ba 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -277,15 +277,7 @@ def check_progress_call(id_part):
preview_visibility = gr_show(False)
if opts.show_progress_every_n_steps > 0:
- if shared.parallel_processing_allowed:
-
- if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
- if opts.show_progress_grid:
- shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
- else:
- shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
- shared.state.current_image_sampling_step = shared.state.sampling_step
-
+ shared.state.set_current_image()
image = shared.state.current_image
if image is None:
--
cgit v1.2.3
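
Note: centralising the preview logic in State.set_current_image gives the UI poller and the API progress endpoint one shared gate on show_progress_every_n_steps. A usage sketch, assuming the webui shared module:

from modules import shared

# safe to call on every poll: the method itself checks whether enough
# sampling steps have passed and whether a latent exists before decoding
shared.state.set_current_image()

if shared.state.current_image is not None:
    shared.state.current_image.save("progress-preview.png")
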
From 9c67408004ed132637d10321bf44565f82055fd2 Mon Sep 17 00:00:00 2001
From: timntorres <116157310+timntorres@users.noreply.github.com>
Date: Wed, 2 Nov 2022 02:18:21 -0700
Subject: Allow saving "before-highres-fix". (#4150)
* Save image/s before doing highres fix.
---
modules/processing.py | 17 +++++++++++++++--
modules/sd_samplers.py | 5 ++---
modules/shared.py | 1 +
3 files changed, 18 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index b541ee2b..2dcf4879 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -521,7 +521,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+ # Only Txt2Img needs an extra argument, n, when saving intermediate images pre highres fix.
+ if isinstance(p, StableDiffusionProcessingTxt2Img):
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, n=n)
+ else:
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
@@ -649,7 +653,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, n=0):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
@@ -685,6 +689,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ # Save a copy of the image/s before doing highres fix, if applicable.
+ if opts.save and not self.do_not_save_samples and opts.save_images_before_highres_fix:
+ for i in range(self.batch_size):
+ # This batch's ith image.
+ img = sd_samplers.sample_to_image(samples, i)
+ # Index that accounts for both batch size and batch count.
+ ind = i + self.batch_size*n
+ images.save_image(img, self.outpath_samples, "", self.all_seeds[ind], self.all_prompts[ind], opts.samples_format, suffix=f"-before-highres-fix")
+
shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 44d4c189..d7fa89a0 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -93,9 +93,8 @@ def single_sample_to_image(sample):
return Image.fromarray(x_sample)
-def sample_to_image(samples):
- return single_sample_to_image(samples[0])
-
+def sample_to_image(samples, index=0):
+ return single_sample_to_image(samples[index])
def samples_to_image_grid(samples):
return images.image_grid([single_sample_to_image(sample) for sample in samples])
diff --git a/modules/shared.py b/modules/shared.py
index e65f6080..ce991424 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -255,6 +255,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
+ "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
--
cgit v1.2.3
From eb5e82c7ddf5e72fa13b83bd1f12d3a07a4de1a4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 2 Nov 2022 12:45:03 +0300
Subject: do not unnecessarily run VAE one more time when saving intermediate
image with hires fix
---
modules/processing.py | 39 ++++++++++++++++++++-------------------
modules/sd_samplers.py | 1 +
modules/shared.py | 2 +-
3 files changed, 22 insertions(+), 20 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 2dcf4879..3a364b5f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -199,7 +199,7 @@ class StableDiffusionProcessing():
def init(self, all_prompts, all_seeds, all_subseeds):
pass
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
raise NotImplementedError()
def close(self):
@@ -521,11 +521,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
- # Only Txt2Img needs an extra argument, n, when saving intermediate images pre highres fix.
- if isinstance(p, StableDiffusionProcessingTxt2Img):
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, n=n)
- else:
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
@@ -653,7 +649,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, n=0):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
@@ -666,9 +662,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
+ """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
+ def save_intermediate(image, index):
+ if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
+ return
+
+ if not isinstance(image, Image.Image):
+ image = sd_samplers.sample_to_image(image, index)
+
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
+
if opts.use_scale_latent_for_hires_fix:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ for i in range(samples.shape[0]):
+ save_intermediate(samples, i)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -678,6 +686,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
+
+ save_intermediate(image, i)
+
image = images.resize_image(0, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
@@ -689,15 +700,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
- # Save a copy of the image/s before doing highres fix, if applicable.
- if opts.save and not self.do_not_save_samples and opts.save_images_before_highres_fix:
- for i in range(self.batch_size):
- # This batch's ith image.
- img = sd_samplers.sample_to_image(samples, i)
- # Index that accounts for both batch size and batch count.
- ind = i + self.batch_size*n
- images.save_image(img, self.outpath_samples, "", self.all_seeds[ind], self.all_prompts[ind], opts.samples_format, suffix=f"-before-highres-fix")
-
shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
@@ -844,8 +846,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
@@ -856,4 +857,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
del x
devices.torch_gc()
- return samples
\ No newline at end of file
+ return samples
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index d7fa89a0..c7c414ef 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -96,6 +96,7 @@ def single_sample_to_image(sample):
def sample_to_image(samples, index=0):
return single_sample_to_image(samples[index])
+
def samples_to_image_grid(samples):
return images.image_grid([single_sample_to_image(sample) for sample in samples])
diff --git a/modules/shared.py b/modules/shared.py
index ce991424..01f47e38 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -256,6 +256,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
+ "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
@@ -322,7 +323,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
- "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
--
cgit v1.2.3
From f2a5cbe6f55592c4c5527b8e0bf99ea8d658f057 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 2 Nov 2022 14:41:29 +0300
Subject: fix #3986 breaking --no-half-vae
---
modules/sd_models.py | 9 +++++++++
1 file changed, 9 insertions(+)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 883639d1..5075fadb 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -183,11 +183,20 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.to(memory_format=torch.channels_last)
if not shared.cmd_opts.no_half:
+ vae = model.first_stage_model
+
+ # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+ if shared.cmd_opts.no_half_vae:
+ model.first_stage_model = None
+
model.half()
+ model.first_stage_model = vae
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+ model.first_stage_model.to(devices.dtype_vae)
+
if shared.opts.sd_checkpoint_cache > 0:
# if PR #4035 were to get merged, restore base VAE first before caching
checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
--
cgit v1.2.3
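
Note: the fix is a detach/reattach trick: model.half() converts every registered submodule, so the VAE is unhooked from the module tree while the rest of the model is halved, then reattached and cast to dtype_vae separately. A generic PyTorch sketch of the same pattern (module names are illustrative):

import torch

model = torch.nn.ModuleDict({
    "unet": torch.nn.Linear(4, 4),
    "first_stage_model": torch.nn.Linear(4, 4),  # stands in for the VAE
})

keep_vae_fp32 = True  # analogue of --no-half-vae

vae = model["first_stage_model"]
if keep_vae_fp32:
    del model["first_stage_model"]  # hide it from .half()

model.half()
model["first_stage_model"] = vae    # reattach, still float32

assert model["unet"].weight.dtype == torch.float16
assert model["first_stage_model"].weight.dtype == torch.float32
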
From 3178c35224467893cf8dcedb1028c59c6c23db58 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Wed, 2 Nov 2022 22:16:32 +0900
Subject: resolve conflicts
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 065b893d..959937d7 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -285,7 +285,7 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('training', "Training"), {
- "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
+ "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
--
cgit v1.2.3
From 9b5f85ac83f864310fe19c9deab6670bad695b0d Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Wed, 2 Nov 2022 22:18:04 +0900
Subject: first revert
---
modules/shared.py | 1 -
1 file changed, 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 959937d7..7e8c552b 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -286,7 +286,6 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
- "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
--
cgit v1.2.3
From 7ea5956ad5fa925f92116e8a3bf78d7f6517b654 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Wed, 2 Nov 2022 22:18:55 +0900
Subject: now add
---
modules/shared.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index d8e99f85..7ecb40d8 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -309,6 +309,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
+ "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
--
cgit v1.2.3
From e21fcd72fcf147904a1df060226c4df12acf251e Mon Sep 17 00:00:00 2001
From: evshiron
Date: Wed, 2 Nov 2022 22:37:45 +0800
Subject: add back png info in image api
---
modules/api/api.py | 21 +++++++++++++++++----
1 file changed, 17 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 71c9c160..ceaf08b0 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -7,8 +7,9 @@ from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
+from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
+from PIL import PngImagePlugin
def upscaler_to_index(name: str):
@@ -31,9 +32,21 @@ def setUpscalers(req: dict):
def encode_pil_to_base64(image):
- buffer = io.BytesIO()
- image.save(buffer, format="png")
- return base64.b64encode(buffer.getvalue())
+ with io.BytesIO() as output_bytes:
+
+ # Copy any text-only metadata
+ use_metadata = False
+ metadata = PngImagePlugin.PngInfo()
+ for key, value in image.info.items():
+ if isinstance(key, str) and isinstance(value, str):
+ metadata.add_text(key, value)
+ use_metadata = True
+
+ image.save(
+ output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
+ )
+ bytes_data = output_bytes.getvalue()
+ return base64.b64encode(bytes_data)
class Api:
--
cgit v1.2.3
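
Note: the reworked encode_pil_to_base64 preserves PNG text chunks (notably the 'parameters' infotext) across the API boundary. A small self-contained sketch of the save/load roundtrip it relies on, using only PIL:

import io
from PIL import Image, PngImagePlugin

image = Image.new("RGB", (64, 64))
image.info["parameters"] = "a prompt, Steps: 20"  # as the webui would set it

metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items():
    if isinstance(key, str) and isinstance(value, str):
        metadata.add_text(key, value)

buf = io.BytesIO()
image.save(buf, "PNG", pnginfo=metadata)

# reading it back shows the text chunk survived the encode
roundtrip = Image.open(io.BytesIO(buf.getvalue()))
assert roundtrip.info["parameters"] == "a prompt, Steps: 20"
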
From a9e979977a8e3999b01b6a086bb1332ab7ab308b Mon Sep 17 00:00:00 2001
From: Artem Zagidulin
Date: Wed, 2 Nov 2022 19:05:01 +0300
Subject: process_one
---
modules/processing.py | 3 +++
modules/scripts.py | 16 ++++++++++++++++
2 files changed, 19 insertions(+)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 3a364b5f..72a2ee4e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -509,6 +509,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(prompts) == 0:
break
+ if p.scripts is not None:
+ p.scripts.process_one(p)
+
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
diff --git a/modules/scripts.py b/modules/scripts.py
index 533db45c..9f82efea 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -70,6 +70,13 @@ class Script:
pass
+ def process_one(self, p, *args):
+ """
+ Same as process(), but called for every iteration
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -294,6 +301,15 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def process_one(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.process_one(p, *script_args)
+ except Exception:
+ print(f"Error running process_one: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
try:
--
cgit v1.2.3
From f1b6ac64e451036fb4dfabe66d79488c56c06776 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kyu=E2=99=A5?= <3ad4gum@gmail.com>
Date: Wed, 2 Nov 2022 17:24:42 +0100
Subject: Added option to preview Created images on batch completion.
---
modules/shared.py | 25 ++++++++++++++++---------
modules/ui.py | 2 +-
2 files changed, 17 insertions(+), 10 deletions(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index d8e99f85..d4cf32a4 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -146,6 +146,9 @@ class State:
self.interrupted = True
def nextjob(self):
+ if opts.show_progress_every_n_steps == -1:
+ self.do_set_current_image()
+
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
@@ -186,17 +189,21 @@ class State:
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
+ self.do_set_current_image()
+
+ def do_set_current_image(self):
if not parallel_processing_allowed:
return
+ if self.current_latent is None:
+ return
+
+ if opts.show_progress_grid:
+ self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
+ else:
+ self.current_image = sd_samplers.sample_to_image(self.current_latent)
- if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None:
- if opts.show_progress_grid:
- self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
- else:
- self.current_image = sd_samplers.sample_to_image(self.current_latent)
-
- self.current_image_sampling_step = self.sampling_step
-
+ self.current_image_sampling_step = self.sampling_step
state = State()
@@ -351,7 +358,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
- "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+ "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
diff --git a/modules/ui.py b/modules/ui.py
index 2609857e..29de1e10 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -276,7 +276,7 @@ def check_progress_call(id_part):
image = gr_show(False)
preview_visibility = gr_show(False)
- if opts.show_progress_every_n_steps > 0:
+ if opts.show_progress_every_n_steps != 0:
shared.state.set_current_image()
image = shared.state.current_image
--
cgit v1.2.3
From c07f1d0d7821f85b9ce1419992c118963d605bd7 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 2 Nov 2022 16:59:10 +0000
Subject: Convert callbacks into a private map, add utility functions for
removing callbacks
---
modules/script_callbacks.py | 68 +++++++++++++++++++++++++++------------------
1 file changed, 41 insertions(+), 27 deletions(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index c28e220e..4a7fb944 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -46,25 +46,23 @@ class CFGDenoiserParams:
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
-callbacks_app_started = []
-callbacks_model_loaded = []
-callbacks_ui_tabs = []
-callbacks_ui_settings = []
-callbacks_before_image_saved = []
-callbacks_image_saved = []
-callbacks_cfg_denoiser = []
+__callback_map = dict(
+ callbacks_app_started=[],
+ callbacks_model_loaded=[],
+ callbacks_ui_tabs=[],
+ callbacks_ui_settings=[],
+ callbacks_before_image_saved=[],
+ callbacks_image_saved=[],
+ callbacks_cfg_denoiser=[]
+)
def clear_callbacks():
- callbacks_model_loaded.clear()
- callbacks_ui_tabs.clear()
- callbacks_ui_settings.clear()
- callbacks_before_image_saved.clear()
- callbacks_image_saved.clear()
- callbacks_cfg_denoiser.clear()
+ for callback_list in __callback_map.values():
+ callback_list.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
- for c in callbacks_app_started:
+ for c in __callback_map['callbacks_app_started']:
try:
c.callback(demo, app)
except Exception:
@@ -72,7 +70,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
def model_loaded_callback(sd_model):
- for c in callbacks_model_loaded:
+ for c in __callback_map['callbacks_model_loaded']:
try:
c.callback(sd_model)
except Exception:
@@ -82,7 +80,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
- for c in callbacks_ui_tabs:
+ for c in __callback_map['callbacks_ui_tabs']:
try:
res += c.callback() or []
except Exception:
@@ -92,7 +90,7 @@ def ui_tabs_callback():
def ui_settings_callback():
- for c in callbacks_ui_settings:
+ for c in __callback_map['callbacks_ui_settings']:
try:
c.callback()
except Exception:
@@ -100,7 +98,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
- for c in callbacks_before_image_saved:
+ for c in __callback_map['callbacks_before_image_saved']:
try:
c.callback(params)
except Exception:
@@ -108,7 +106,7 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams):
- for c in callbacks_image_saved:
+ for c in __callback_map['callbacks_image_saved']:
try:
c.callback(params)
except Exception:
@@ -116,7 +114,7 @@ def image_saved_callback(params: ImageSaveParams):
def cfg_denoiser_callback(params: CFGDenoiserParams):
- for c in callbacks_cfg_denoiser:
+ for c in __callback_map['callbacks_cfg_denoiser']:
try:
c.callback(params)
except Exception:
@@ -129,17 +127,33 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
+
+def remove_current_script_callbacks():
+ stack = [x for x in inspect.stack() if x.filename != __file__]
+ filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+ if filename == 'unknown file':
+ return
+ for callback_list in __callback_map.values():
+ for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
+ callback_list.remove(callback_to_remove)
+
+
+def remove_callbacks_for_function(callback_func):
+ for callback_list in __callback_map.values():
+ for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
+ callback_list.remove(callback_to_remove)
+
def on_app_started(callback):
"""register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments"""
- add_callback(callbacks_app_started, callback)
+ add_callback(__callback_map['callbacks_app_started'], callback)
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
- add_callback(callbacks_model_loaded, callback)
+ add_callback(__callback_map['callbacks_model_loaded'], callback)
def on_ui_tabs(callback):
@@ -152,13 +166,13 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
- add_callback(callbacks_ui_tabs, callback)
+ add_callback(__callback_map['callbacks_ui_tabs'], callback)
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
- add_callback(callbacks_ui_settings, callback)
+ add_callback(__callback_map['callbacks_ui_settings'], callback)
def on_before_image_saved(callback):
@@ -166,7 +180,7 @@ def on_before_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
"""
- add_callback(callbacks_before_image_saved, callback)
+ add_callback(__callback_map['callbacks_before_image_saved'], callback)
def on_image_saved(callback):
@@ -174,7 +188,7 @@ def on_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
- add_callback(callbacks_image_saved, callback)
+ add_callback(__callback_map['callbacks_image_saved'], callback)
def on_cfg_denoiser(callback):
@@ -182,5 +196,5 @@ def on_cfg_denoiser(callback):
The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
- add_callback(callbacks_cfg_denoiser, callback)
+ add_callback(__callback_map['callbacks_cfg_denoiser'], callback)
--
cgit v1.2.3
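
Note: hiding the lists behind __callback_map is what makes the two new removal helpers reliable; they are mainly useful for extensions that want to unhook cleanly, e.g. on script reload. A minimal usage sketch:

from modules import script_callbacks

def my_callback(params):
    pass

script_callbacks.on_cfg_denoiser(my_callback)

# later, unhook by function identity...
script_callbacks.remove_callbacks_for_function(my_callback)

# ...or drop everything the calling script file ever registered
script_callbacks.remove_current_script_callbacks()
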
From de64146ad2fc2030a4cd3545676f9e18c93b8b18 Mon Sep 17 00:00:00 2001
From: Artem Zagidulin
Date: Wed, 2 Nov 2022 21:30:50 +0300
Subject: add iteration number
---
modules/processing.py | 2 +-
modules/scripts.py | 6 +++---
2 files changed, 4 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 72a2ee4e..17f4a5ec 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -510,7 +510,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
break
if p.scripts is not None:
- p.scripts.process_one(p)
+ p.scripts.process_one(p, n)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
diff --git a/modules/scripts.py b/modules/scripts.py
index 9f82efea..7aa0d56a 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -70,7 +70,7 @@ class Script:
pass
- def process_one(self, p, *args):
+ def process_one(self, p, n, *args):
"""
Same as process(), but called for every iteration
"""
@@ -301,11 +301,11 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- def process_one(self, p):
+ def process_one(self, p, n):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
- script.process_one(p, *script_args)
+ script.process_one(p, n, *script_args)
except Exception:
print(f"Error running process_one: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
--
cgit v1.2.3
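
Note: with the iteration index threaded through, an always-on script can react once per batch instead of once per job. A sketch of a script using the hook in its final form, process_one(self, p, n, *args); using extra_generation_params as the target is an assumption for illustration:

import modules.scripts as scripts

class PerBatchScript(scripts.Script):
    def title(self):
        return "Per-batch demo"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def process_one(self, p, n, *args):
        # n is the 0-based batch-count index; the hook runs before the
        # conditioning for this iteration is computed
        p.extra_generation_params["Batch index"] = n
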
From 2ac25ea64f31fd0e7dea35d27a52f3646618c3b6 Mon Sep 17 00:00:00 2001
From: digburn
Date: Wed, 2 Nov 2022 21:52:23 +0000
Subject: fix: Add required parameter to API extras route
---
modules/api/models.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/api/models.py b/modules/api/models.py
index 9ee42a17..9069c0ac 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -131,6 +131,7 @@ class ExtrasBaseRequest(BaseModel):
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
+ upscale_first: bool = Field(default=True, title="Upscale first", description="Should the upscaler run before restoring faces?")
class ExtraBaseResponse(BaseModel):
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
--
cgit v1.2.3
From 313e14de04d9955c6ad077341feceb0fc7f2f1d3 Mon Sep 17 00:00:00 2001
From: Chris OBryan <13701027+cobryan05@users.noreply.github.com>
Date: Wed, 2 Nov 2022 21:37:43 -0500
Subject: extras - skip unnecessary second hash of image
There is no need to re-hash the input image each iteration of the loop.
This also reverts PR #4026 as it was determined the cache hits it avoids
were actually valid.
---
modules/extras.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index 8e2ab35c..71b93a06 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -136,12 +136,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None
+ image_hash: str = hash(np.array(image.getdata()).tobytes())
for upscaler in params:
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
- cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
+ cache_key = LruCache.Key(image_hash=image_hash,
info_hash=hash(info),
- args_hash=hash((upscale_args, upscale_first)))
+ args_hash=hash(upscale_args))
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
--
cgit v1.2.3
From 7a2e36b583ef9eaefa44322e16faff6f9f1af169 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Thu, 3 Nov 2022 00:51:22 -0300
Subject: Add config and lists endpoints
---
modules/api/api.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++++---
modules/api/models.py | 70 +++++++++++++++++++++++++++++++++++--
2 files changed, 159 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 71c9c160..ed2dce5d 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -2,14 +2,17 @@ import base64
import io
import time
import uvicorn
-from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
-from fastapi import APIRouter, Depends, HTTPException
+from threading import Lock
+from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
+from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
+from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
-
+from modules.sd_models import checkpoints_list
+from modules.realesrgan_model import get_realesrgan_models
+from typing import List
def upscaler_to_index(name: str):
try:
@@ -37,7 +40,7 @@ def encode_pil_to_base64(image):
class Api:
- def __init__(self, app, queue_lock):
+ def __init__(self, app: FastAPI, queue_lock: Lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
@@ -48,6 +51,19 @@ class Api:
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
+ self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
+ self.app.add_api_route("/sdapi/v1/info", self.get_info, methods=["GET"])
+ self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
+ self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
+ self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
+ self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
+ self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
+ self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
+ self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
+ self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
+ self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -190,6 +206,77 @@ class Api:
shared.state.interrupt()
return {}
+
+ def get_config(self):
+ options = {}
+ for key in shared.opts.data.keys():
+ metadata = shared.opts.data_labels.get(key)
+ if(metadata is not None):
+ options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
+ else:
+ options.update({key: shared.opts.data.get(key, None)})
+
+ return options
+
+ def set_config(self, req: OptionsModel):
+ reqDict = vars(req)
+ for o in reqDict:
+ setattr(shared.opts, o, reqDict[o])
+
+ shared.opts.save(shared.config_filename)
+ return
+
+ def get_cmd_flags(self):
+ return vars(shared.cmd_opts)
+
+ def get_info(self):
+
+ return {
+ "hypernetworks": [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks],
+ "face_restorers": [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers],
+ "realesrgan_models":[{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)],
+ "promp_styles":[shared.prompt_styles.styles[k] for k in shared.prompt_styles.styles],
+ "artists_categories": shared.artist_db.cats,
+ # "artists": [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
+ }
+
+ def get_samplers(self):
+ return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
+
+ def get_upscalers(self):
+ upscalers = []
+
+ for upscaler in shared.sd_upscalers:
+ u = upscaler.scaler
+ upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
+
+ return upscalers
+
+ def get_sd_models(self):
+ return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
+
+ def get_hypernetworks(self):
+ return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
+
+ def get_face_restorers(self):
+ return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
+
+ def get_realesrgan_models(self):
+ return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
+
+ def get_promp_styles(self):
+ styleList = []
+ for k in shared.prompt_styles.styles:
+ style = shared.prompt_styles.styles[k]
+ styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
+
+ return styleList
+
+ def get_artists_categories(self):
+ return shared.artist_db.cats
+
+ def get_artists(self):
+ return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
index 9ee42a17..b54b188a 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,11 +1,10 @@
import inspect
-from click import prompt
from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional
+from typing import Any, Optional, Union
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
-from modules.shared import sd_upscalers
+from modules.shared import sd_upscalers, opts, parser
API_NOT_ALLOWED = [
"self",
@@ -165,3 +164,68 @@ class ProgressResponse(BaseModel):
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+
+fields = {}
+for key, value in opts.data.items():
+ metadata = opts.data_labels.get(key)
+ optType = opts.typemap.get(type(value), type(value))
+
+ if (metadata is not None):
+ fields.update({key: (Optional[optType], Field(
+ default=metadata.default ,description=metadata.label))})
+ else:
+ fields.update({key: (Optional[optType], Field())})
+
+OptionsModel = create_model("Options", **fields)
+
+flags = {}
+_options = vars(parser)['_option_string_actions']
+for key in _options:
+ if(_options[key].dest != 'help'):
+ flag = _options[key]
+ _type = str
+ if(_options[key].default != None): _type = type(_options[key].default)
+ flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
+
+FlagsModel = create_model("Flags", **flags)
+
+class SamplerItem(BaseModel):
+ name: str = Field(title="Name")
+ aliases: list[str] = Field(title="Aliases")
+ options: dict[str, str] = Field(title="Options")
+
+class UpscalerItem(BaseModel):
+ name: str = Field(title="Name")
+ model_name: str | None = Field(title="Model Name")
+ model_path: str | None = Field(title="Path")
+ model_url: str | None = Field(title="URL")
+
+class SDModelItem(BaseModel):
+ title: str = Field(title="Title")
+ model_name: str = Field(title="Model Name")
+ hash: str = Field(title="Hash")
+ filename: str = Field(title="Filename")
+ config: str = Field(title="Config file")
+
+class HypernetworkItem(BaseModel):
+ name: str = Field(title="Name")
+ path: str | None = Field(title="Path")
+
+class FaceRestorerItem(BaseModel):
+ name: str = Field(title="Name")
+ cmd_dir: str | None = Field(title="Path")
+
+class RealesrganItem(BaseModel):
+ name: str = Field(title="Name")
+ path: str | None = Field(title="Path")
+ scale: int | None = Field(title="Scale")
+
+class PromptStyleItem(BaseModel):
+ name: str = Field(title="Name")
+ prompt: str | None = Field(title="Prompt")
+ negative_prompt: str | None = Field(title="Negative Prompt")
+
+class ArtistItem(BaseModel):
+ name: str = Field(title="Name")
+ score: float = Field(title="Score")
+ category: str = Field(title="Category")
\ No newline at end of file
--
cgit v1.2.3
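A note on the commit above: OptionsModel and FlagsModel are not written out by hand. pydantic's create_model builds them at import time from a dict mapping field names to (type, Field) tuples, one dict derived from shared.opts and one from the argparse parser. A minimal, self-contained sketch of the same pattern, using the pydantic v1-style API of that era and invented option names:

from typing import Optional
from pydantic import Field, create_model

# Hypothetical stand-ins for shared.opts.data / shared.opts.data_labels.
options_data = {"sd_model_checkpoint": "v1-5.ckpt", "CLIP_stop_at_last_layers": 1}
option_labels = {"sd_model_checkpoint": "Stable Diffusion checkpoint",
                 "CLIP_stop_at_last_layers": "Clip skip"}

fields = {}
for key, value in options_data.items():
    # Optional[...] lets a POST body update just a subset of the options.
    fields[key] = (Optional[type(value)], Field(default=value, description=option_labels.get(key)))

OptionsModel = create_model("Options", **fields)

print(OptionsModel().dict())                              # all defaults
print(OptionsModel(CLIP_stop_at_last_layers=2).dict())    # one override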
From 743fffa3d6c2e9e6bb5f48093a4c88f3b53e001d Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Thu, 3 Nov 2022 00:52:01 -0300
Subject: Remove unused endpoint
---
modules/api/api.py | 12 ------------
1 file changed, 12 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index ed2dce5d..a49f3755 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -54,7 +54,6 @@ class Api:
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
- self.app.add_api_route("/sdapi/v1/info", self.get_info, methods=["GET"])
self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
@@ -229,17 +228,6 @@ class Api:
def get_cmd_flags(self):
return vars(shared.cmd_opts)
- def get_info(self):
-
- return {
- "hypernetworks": [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks],
- "face_restorers": [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers],
- "realesrgan_models":[{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)],
- "promp_styles":[shared.prompt_styles.styles[k] for k in shared.prompt_styles.styles],
- "artists_categories": shared.artist_db.cats,
- # "artists": [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
- }
-
def get_samplers(self):
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
--
cgit v1.2.3
From e33d6cbddd08870e348d10a58af41fb677a39fd6 Mon Sep 17 00:00:00 2001
From: Ju1-js <40339350+Ju1-js@users.noreply.github.com>
Date: Wed, 2 Nov 2022 21:04:49 -0700
Subject: Make extension manager Remote links open a new tab
---
modules/ui_extensions.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index ab807722..a81de9a7 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -86,7 +86,7 @@ def extension_table():
code += f"""
{html.escape(ext.name)}
- <a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a>
+ <a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape(ext.remote or '')}</a>
{ext_status}
"""
--
cgit v1.2.3
From 0b143c1163a96b193a4e8512be9c5831c661a50d Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Thu, 3 Nov 2022 14:30:53 +0900
Subject: Separate .optim file from model
---
modules/hypernetworks/hypernetwork.py | 12 ++++++++----
1 file changed, 8 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 8f74cdea..63c25de8 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -161,6 +161,7 @@ class Hypernetwork:
def save(self, filename):
state_dict = {}
+ optimizer_saved_dict = {}
for k, v in self.layers.items():
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -175,9 +176,10 @@ class Hypernetwork:
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
if self.optimizer_name is not None:
- state_dict['optimizer_name'] = self.optimizer_name
+ optimizer_saved_dict['optimizer_name'] = self.optimizer_name
if self.optimizer_state_dict:
- state_dict['optimizer_state_dict'] = self.optimizer_state_dict
+ optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
+ torch.save(optimizer_saved_dict, filename + '.optim')
torch.save(state_dict, filename)
@@ -198,9 +200,11 @@ class Hypernetwork:
print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}")
- self.optimizer_name = state_dict.get('optimizer_name', 'AdamW')
+
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
+ self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
print(f"Optimizer name is {self.optimizer_name}")
- self.optimizer_state_dict = state_dict.get('optimizer_state_dict', None)
+ self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
if self.optimizer_state_dict:
print("Loaded existing optimizer from checkpoint")
else:
--
cgit v1.2.3
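Writing the optimizer state to a side-car .optim file keeps the hypernetwork .pt itself small and shareable while still letting training resume later. A minimal sketch of the same save/load split with a toy module (the file name and model below are illustrative, not the webui code):

import os
import torch

model = torch.nn.Linear(4, 4)          # toy stand-in for the hypernetwork layers
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

def save(filename):
    torch.save(model.state_dict(), filename)              # weights only in the .pt
    torch.save({"optimizer_name": "AdamW",
                "optimizer_state_dict": optimizer.state_dict()},
               filename + ".optim")                       # resume data in the side-car

def load(filename):
    model.load_state_dict(torch.load(filename))
    if os.path.exists(filename + ".optim"):               # side-car is optional
        saved = torch.load(filename + ".optim", map_location="cpu")
        optimizer.load_state_dict(saved["optimizer_state_dict"])

save("hn.pt")
load("hn.pt")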
From 1764ac3c8bc482bd575987850e96630d9115e51a Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Thu, 3 Nov 2022 14:49:26 +0900
Subject: use hash to check valid optim
---
modules/hypernetworks/hypernetwork.py | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 63c25de8..4230b8cf 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -177,12 +177,13 @@ class Hypernetwork:
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
if self.optimizer_name is not None:
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
+
+ torch.save(state_dict, filename)
if self.optimizer_state_dict:
+ optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(optimizer_saved_dict, filename + '.optim')
- torch.save(state_dict, filename)
-
def load(self, filename):
self.filename = filename
if self.name is None:
@@ -204,7 +205,10 @@ class Hypernetwork:
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
print(f"Optimizer name is {self.optimizer_name}")
- self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+ if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+ self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+ else:
+ self.optimizer_state_dict = None
if self.optimizer_state_dict:
print("Loaded existing optimizer from checkpoint")
else:
@@ -229,7 +233,7 @@ def list_hypernetworks(path):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
- res[name] = filename
+ res[name + f"({sd_models.model_hash(filename)})"] = filename
return res
@@ -375,6 +379,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
else:
hypernetwork_dir = None
+ hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
--
cgit v1.2.3
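The hash stored in the .optim file ties it to the exact weights file it was saved alongside; on load, the optimizer state is discarded if the model file's hash no longer matches. A sketch of that guard, with a whole-file sha256 standing in for sd_models.model_hash (which hashes only a slice of the file):

import hashlib
import os
import torch

def file_hash(filename):
    # Stand-in for sd_models.model_hash: here, sha256 over the whole file.
    with open(filename, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()[:8]

def save_optimizer_state(filename, optimizer):
    torch.save({"hash": file_hash(filename),
                "optimizer_state_dict": optimizer.state_dict()},
               filename + ".optim")

def load_optimizer_state(filename, optimizer):
    optim_file = filename + ".optim"
    if not os.path.exists(optim_file):
        return False
    saved = torch.load(optim_file, map_location="cpu")
    if saved.get("hash") != file_hash(filename):
        return False    # weights were replaced; the optimizer state is stale
    optimizer.load_state_dict(saved["optimizer_state_dict"])
    return True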
From 86b7fc6e5ed56327fa12b444ca2444b13eb98aa8 Mon Sep 17 00:00:00 2001
From: thesved <2893181+thesved@users.noreply.github.com>
Date: Thu, 3 Nov 2022 19:44:47 +0100
Subject: Make DDIM and PLMS work on Mac OS
Fix register_buffer error on Mac OS
---
modules/sd_hijack_inpainting.py | 19 ++++++++++++++++++-
1 file changed, 18 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index fd92a335..202b42cf 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -1,4 +1,5 @@
import torch
+import modules.devices as devices
from einops import repeat
from omegaconf import ListConfig
@@ -314,6 +315,20 @@ class LatentInpaintDiffusion(LatentDiffusion):
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
+
+
+# =================================================================================================
+# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
+# =================================================================================================
+def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ optimal_type = devices.get_optimal_device()
+ if attr.device != optimal_type:
+ if getattr(torch, 'has_mps', False):
+ attr = attr.to(device="mps", dtype=torch.float32)
+ else:
+ attr = attr.to(optimal_type)
+ setattr(self, name, attr)
def should_hijack_inpainting(checkpoint_info):
@@ -326,6 +341,8 @@ def do_inpainting_hijack():
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
+ ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
- ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
\ No newline at end of file
+ ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
+ ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
--
cgit v1.2.3
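The patched register_buffer coerces any tensor headed for an MPS device to float32, since MPS cannot hold float64 buffers, which is what breaks DDIM/PLMS on macOS. A self-contained sketch of the same coercion; optimal_device below is a simplified stand-in for modules.devices.get_optimal_device:

import torch

def optimal_device():
    # Simplified stand-in for modules.devices.get_optimal_device.
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def register_buffer(self, name, attr):
    if isinstance(attr, torch.Tensor):
        device = optimal_device()
        if attr.device != device:
            if device.type == "mps":
                attr = attr.to(device="mps", dtype=torch.float32)  # MPS cannot hold float64
            else:
                attr = attr.to(device)
    setattr(self, name, attr)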
From b2c48091db394c2b7d375a33f18d90c924cd4363 Mon Sep 17 00:00:00 2001
From: Gur
Date: Fri, 4 Nov 2022 06:55:03 +0800
Subject: Fixed API compatibility with Python 3.8
---
modules/api/models.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/api/models.py b/modules/api/models.py
index 9ee42a17..29a934ba 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -6,6 +6,7 @@ from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers
+from typing import List
API_NOT_ALLOWED = [
"self",
@@ -109,12 +110,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
).generate_model()
class TextToImageResponse(BaseModel):
- images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ImageToImageResponse(BaseModel):
- images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
@@ -146,10 +147,10 @@ class FileData(BaseModel):
name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
- imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+ imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: list[str] = Field(title="Images", description="The generated images in base64 format.")
+ images: List[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
--
cgit v1.2.3
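Built-in generics such as list[str] are only subscriptable from Python 3.9 onward (PEP 585); on 3.8 the class bodies above fail at import time with TypeError: 'type' object is not subscriptable, hence the switch to typing.List. A trimmed illustration of the fixed form:

from typing import List
from pydantic import BaseModel, Field

class TextToImageResponse(BaseModel):
    # "images: list[str]" raises TypeError on Python 3.8;
    # typing.List works on 3.8 and 3.9+ alike.
    images: List[str] = Field(default=None, title="Image")
    parameters: dict
    info: str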
From 8eb64dab3e9e40531f6a3fa606a1c23a62987249 Mon Sep 17 00:00:00 2001
From: digburn <115176097+digburn@users.noreply.github.com>
Date: Fri, 4 Nov 2022 00:35:18 +0000
Subject: fix: correct default value of upscale_first to False
---
modules/api/models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/api/models.py b/modules/api/models.py
index 9069c0ac..68fb45c6 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -131,7 +131,7 @@ class ExtrasBaseRequest(BaseModel):
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
- upscale_first: bool = Field(default=True, title="Upscale first", description="Should the upscaler run before restoring faces?")
+ upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
class ExtraBaseResponse(BaseModel):
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
--
cgit v1.2.3
From 3780ad3ad837dd406da39eebd5d91009b5a58445 Mon Sep 17 00:00:00 2001
From: digburn
Date: Fri, 4 Nov 2022 00:40:21 +0000
Subject: fix: loading models without vae from cache
---
modules/sd_models.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 5075fadb..ae427a5c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -204,8 +204,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoints_loaded.popitem(last=False) # LRU
else:
- vae_name = sd_vae.get_filename(vae_file)
- print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache")
+ vae_name = sd_vae.get_filename(vae_file) if vae_file else None
+ vae_message = f" with {vae_name} VAE" if vae_name else ""
+ print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
checkpoints_loaded.move_to_end(checkpoint_key)
model.load_state_dict(checkpoints_loaded[checkpoint_key])
--
cgit v1.2.3
From e533ff61c1baa4ad047f9c8dc05c17b64ee89ddf Mon Sep 17 00:00:00 2001
From: timntorres
Date: Thu, 3 Nov 2022 22:28:22 -0700
Subject: Lift extras generate button a la #4246.
---
modules/ui.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 2609857e..6461002a 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1052,6 +1052,8 @@ def create_ui(wrap_gradio_gpu_call):
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
show_extras_results = gr.Checkbox(label='Show result images', value=True)
+ submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
+
with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'):
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
@@ -1079,8 +1081,6 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
- submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
submit.click(
--
cgit v1.2.3
From 4dd898b8c15e342f817d3fb1c8dc9f2d5d111022 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 4 Nov 2022 08:38:11 +0300
Subject: do not mess with components' visibility for scripts; instead create
group components and show/hide those; this will break scripts that create
invisible components and rely on UI, but the earlier I make this change the
better
---
modules/scripts.py | 34 ++++++++++++++++++----------------
1 file changed, 18 insertions(+), 16 deletions(-)
(limited to 'modules')
diff --git a/modules/scripts.py b/modules/scripts.py
index 533db45c..28ce07f4 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -18,6 +18,9 @@ class Script:
args_to = None
alwayson = False
+ """A gr.Group component that has all script's UI inside it"""
+ group = None
+
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
@@ -218,8 +221,6 @@ class ScriptRunner:
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
- if not script.alwayson:
- control.visible = False
if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields
@@ -229,40 +230,41 @@ class ScriptRunner:
script.args_to = len(inputs)
for script in self.alwayson_scripts:
- with gr.Group():
+ with gr.Group() as group:
create_script_ui(script, inputs, inputs_alwayson)
+ script.group = group
+
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs[0] = dropdown
for script in self.selectable_scripts:
- create_script_ui(script, inputs, inputs_alwayson)
+ with gr.Group(visible=False) as group:
+ create_script_ui(script, inputs, inputs_alwayson)
+
+ script.group = group
def select_script(script_index):
- if 0 < script_index <= len(self.selectable_scripts):
- script = self.selectable_scripts[script_index-1]
- args_from = script.args_from
- args_to = script.args_to
- else:
- args_from = 0
- args_to = 0
+ selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
- return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
+ return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
def init_field(title):
+ """called when an initial value is set from ui-config.json to show script's UI components"""
+
if title == 'None':
return
+
script_index = self.titles.index(title)
- script = self.selectable_scripts[script_index]
- for i in range(script.args_from, script.args_to):
- inputs[i].visible = True
+ self.selectable_scripts[script_index].group.visible = True
dropdown.init_field = init_field
+
dropdown.change(
fn=select_script,
inputs=[dropdown],
- outputs=inputs
+ outputs=[script.group for script in self.selectable_scripts]
)
return inputs
--
cgit v1.2.3
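Instead of toggling visibility on each control, every script's UI is now wrapped in one gr.Group, and the dropdown's change handler returns a gr.update(visible=...) per group. A minimal sketch of the pattern with the gradio 3.x-era API (the two scripts and their controls are invented for illustration):

import gradio as gr

with gr.Blocks() as demo:
    dropdown = gr.Dropdown(label="Script", choices=["None", "A", "B"], value="None", type="index")

    groups = []
    for title in ["A", "B"]:
        with gr.Group(visible=False) as group:   # one container per script's whole UI
            gr.Textbox(label=f"Option for script {title}")
        groups.append(group)

    def select_script(script_index):
        selected = groups[script_index - 1] if script_index > 0 else None
        # One update per group: only the chosen script's group becomes visible.
        return [gr.update(visible=g is selected) for g in groups]

    dropdown.change(fn=select_script, inputs=[dropdown], outputs=groups)

demo.launch()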
From f2b69709eaff88fc3a2bd49585556ec0883bf5ea Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 4 Nov 2022 09:42:25 +0300
Subject: move option access checking to options class out of various places
scattered through code
---
modules/processing.py | 4 ++--
modules/shared.py | 11 +++++++++++
modules/ui.py | 20 +++++---------------
3 files changed, 18 insertions(+), 17 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 2168208c..a46e592d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -418,13 +418,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
- opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible
+ setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model, hypernet impossible
res = process_images_inner(p)
finally:
for k, v in stored_opts.items():
- opts.data[k] = v
+ setattr(opts, k, v)
return res
diff --git a/modules/shared.py b/modules/shared.py
index d8e99f85..024c771a 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -396,6 +396,15 @@ class Options:
def __setattr__(self, key, value):
if self.data is not None:
if key in self.data or key in self.data_labels:
+ assert not cmd_opts.freeze_settings, "changing settings is disabled"
+
+ comp_args = opts.data_labels[key].component_args
+ if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
+ raise RuntimeError(f"not possible to set {key} because it is restricted")
+
+ if cmd_opts.hide_ui_dir_config and key in restricted_opts:
+ raise RuntimeError(f"not possible to set {key} because it is restricted")
+
self.data[key] = value
return
@@ -412,6 +421,8 @@ class Options:
return super(Options, self).__getattribute__(item)
def save(self, filename):
+ assert not cmd_opts.freeze_settings, "saving settings is disabled"
+
with open(filename, "w", encoding="utf8") as file:
json.dump(self.data, file, indent=4)
diff --git a/modules/ui.py b/modules/ui.py
index b2b1c854..633b56ef 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1438,8 +1438,6 @@ def create_ui(wrap_gradio_gpu_call):
def run_settings(*args):
changed = 0
- assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
-
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
@@ -1448,15 +1446,9 @@ def create_ui(wrap_gradio_gpu_call):
if comp == dummy_component:
continue
- comp_args = opts.data_labels[key].component_args
- if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
- continue
-
- if cmd_opts.hide_ui_dir_config and key in restricted_opts:
- continue
-
oldval = opts.data.get(key, None)
- opts.data[key] = value
+
+ setattr(opts, key, value)
if oldval != value:
if opts.data_labels[key].onchange is not None:
@@ -1469,17 +1461,15 @@ def create_ui(wrap_gradio_gpu_call):
return f'{changed} settings changed.', opts.dumpjson()
def run_settings_single(value, key):
- assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
-
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
oldval = opts.data.get(key, None)
- if cmd_opts.hide_ui_dir_config and key in restricted_opts:
+ try:
+ setattr(opts, key, value)
+ except Exception:
return gr.update(value=oldval), opts.dumpjson()
- opts.data[key] = value
-
if oldval != value:
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
--
cgit v1.2.3
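With the checks moved into Options.__setattr__, every write path (the settings page, the API's set_config, override_settings in processing) hits the same validation, and callers can simply try/except around setattr, as run_settings_single now does. A condensed sketch of a guarded settings object; freeze and restricted are illustrative stand-ins for cmd_opts.freeze_settings and restricted_opts:

class Options:
    def __init__(self, defaults, freeze=False, restricted=()):
        # object.__setattr__ bypasses the guard while the instance is built.
        object.__setattr__(self, "data", dict(defaults))
        object.__setattr__(self, "freeze", freeze)
        object.__setattr__(self, "restricted", set(restricted))

    def __setattr__(self, key, value):
        if key in self.data:
            assert not self.freeze, "changing settings is disabled"
            if key in self.restricted:
                raise RuntimeError(f"not possible to set {key} because it is restricted")
            self.data[key] = value
            return
        object.__setattr__(self, key, value)

    def __getattr__(self, item):
        if item in self.data:
            return self.data[item]
        raise AttributeError(item)

opts = Options({"outdir_samples": "outputs"}, restricted={"outdir_samples"})
try:
    opts.outdir_samples = "elsewhere"   # rejected: the key is restricted
except RuntimeError as e:
    print(e)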
From 0abb39f461baa343ae7c23abffb261e57c3168d4 Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Fri, 4 Nov 2022 15:47:19 +0900
Subject: resolve conflict - first revert
---
modules/hypernetworks/hypernetwork.py | 123 ++++++++++++++--------------------
1 file changed, 52 insertions(+), 71 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 4230b8cf..674fcedd 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -21,7 +21,6 @@ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_norm
from collections import defaultdict, deque
from statistics import stdev, mean
-optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
@@ -34,9 +33,12 @@ class HypernetworkModule(torch.nn.Module):
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
}
- activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+ activation_dict.update(
+ {cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if
+ inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
- def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
+ add_layer_norm=False, use_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@@ -47,7 +49,7 @@ class HypernetworkModule(torch.nn.Module):
for i in range(len(layer_structure) - 1):
# Add a fully-connected layer
- linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+ linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
# Add an activation func
if activation_func == "linear" or activation_func is None:
@@ -59,7 +61,7 @@ class HypernetworkModule(torch.nn.Module):
# Add layer normalization
if add_layer_norm:
- linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))
# Add dropout except last layer
if use_dropout and i < len(layer_structure) - 3:
@@ -128,7 +130,8 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None,
+ add_layer_norm=False, use_dropout=False):
self.filename = None
self.name = name
self.layers = {}
@@ -140,13 +143,13 @@ class Hypernetwork:
self.weight_init = weight_init
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
- self.optimizer_name = None
- self.optimizer_state_dict = None
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout),
)
def weights(self):
@@ -161,7 +164,6 @@ class Hypernetwork:
def save(self, filename):
state_dict = {}
- optimizer_saved_dict = {}
for k, v in self.layers.items():
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -175,14 +177,8 @@ class Hypernetwork:
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
- if self.optimizer_name is not None:
- optimizer_saved_dict['optimizer_name'] = self.optimizer_name
torch.save(state_dict, filename)
- if self.optimizer_state_dict:
- optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
- optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
- torch.save(optimizer_saved_dict, filename + '.optim')
def load(self, filename):
self.filename = filename
@@ -202,23 +198,13 @@ class Hypernetwork:
self.use_dropout = state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}")
- optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
- self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
- print(f"Optimizer name is {self.optimizer_name}")
- if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
- self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
- else:
- self.optimizer_state_dict = None
- if self.optimizer_state_dict:
- print("Loaded existing optimizer from checkpoint")
- else:
- print("No saved optimizer exists in checkpoint")
-
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
+ self.add_layer_norm, self.use_dropout),
)
self.name = state_dict.get('name', self.name)
@@ -233,7 +219,7 @@ def list_hypernetworks(path):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
- res[name + f"({sd_models.model_hash(filename)})"] = filename
+ res[name] = filename
return res
@@ -330,7 +316,7 @@ def statistics(data):
std = 0
else:
std = stdev(data)
- total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
+ total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})"
recent_data = data[-32:]
if len(recent_data) < 2:
std = 0
@@ -340,7 +326,7 @@ def statistics(data):
return total_information, recent_information
-def report_statistics(loss_info:dict):
+def report_statistics(loss_info: dict):
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
for key in keys:
try:
@@ -352,14 +338,18 @@ def report_statistics(loss_info:dict):
print(e)
-
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
+ training_height, steps, create_image_every, save_hypernetwork_every, template_file,
+ preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
+ preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
- textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps,
+ save_hypernetwork_every, create_image_every, log_directory,
+ name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
@@ -379,7 +369,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
else:
hypernetwork_dir = None
- hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
@@ -395,39 +384,34 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-
+
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
+ height=training_height,
+ repeats=shared.opts.training_image_repeats_per_epoch,
+ placeholder_token=hypernetwork_name,
+ model=shared.sd_model, device=devices.device,
+ template_file=template_file, include_cond=True,
+ batch_size=batch_size)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
size = len(ds.indexes)
- loss_dict = defaultdict(lambda : deque(maxlen = 1024))
+ loss_dict = defaultdict(lambda: deque(maxlen=1024))
losses = torch.zeros((size,))
previous_mean_losses = [0]
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
-
+
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
- # Here we use optimizer from saved HN, or we can specify as UI option.
- if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
- optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
- else:
- print(f"Optimizer type {optimizer_name} is not defined!")
- optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
- optimizer_name = 'AdamW'
- if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
- try:
- optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
- except RuntimeError as e:
- print("Cannot resume from saved optimizer!")
- print(e)
+ # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
steps_without_grad = 0
@@ -441,7 +425,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if len(loss_dict) > 0:
previous_mean_losses = [i[-1] for i in loss_dict.values()]
previous_mean_loss = mean(previous_mean_losses)
-
+
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
@@ -460,7 +444,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item()
for entry in entries:
loss_dict[entry.filename].append(loss.item())
-
+
optimizer.zero_grad()
weights[0].grad = None
loss.backward()
@@ -475,9 +459,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
steps_done = hypernetwork.step + 1
- if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
-
+
if len(previous_mean_losses) > 1:
std = stdev(previous_mean_losses)
else:
@@ -489,11 +473,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
- hypernetwork.optimizer_name = optimizer_name
- if shared.opts.save_optimizer_state:
- hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
- hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
+
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}",
"learn_rate": scheduler.learn_rate
@@ -529,7 +510,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
preview_text = p.prompt
processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images)>0 else None
+ image = processed.images[0] if len(processed.images) > 0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -537,7 +518,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None:
shared.state.current_image = image
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
+ shared.opts.samples_format, processed.infotexts[0],
+ p=p, forced_filename=forced_filename,
+ save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
@@ -551,15 +535,12 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
+
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
- hypernetwork.optimizer_name = optimizer_name
- if shared.opts.save_optimizer_state:
- hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
- del optimizer
- hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
+
return hypernetwork, filename
@@ -576,4 +557,4 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
hypernetwork.sd_checkpoint = old_sd_checkpoint
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
hypernetwork.name = old_hypernetwork_name
- raise
+ raise
\ No newline at end of file
--
cgit v1.2.3
From 0d07cbfa15d34294a4fa22d74359cdd6fe2f799c Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Fri, 4 Nov 2022 15:50:54 +0900
Subject: I blame code autocomplete
---
modules/hypernetworks/hypernetwork.py | 76 +++++++++++++----------------------
1 file changed, 27 insertions(+), 49 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 674fcedd..a11e01d6 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -33,12 +33,9 @@ class HypernetworkModule(torch.nn.Module):
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
}
- activation_dict.update(
- {cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if
- inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+ activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
- def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
- add_layer_norm=False, use_dropout=False):
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@@ -49,7 +46,7 @@ class HypernetworkModule(torch.nn.Module):
for i in range(len(layer_structure) - 1):
# Add a fully-connected layer
- linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
+ linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# Add an activation func
if activation_func == "linear" or activation_func is None:
@@ -61,7 +58,7 @@ class HypernetworkModule(torch.nn.Module):
# Add layer normalization
if add_layer_norm:
- linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout except last layer
if use_dropout and i < len(layer_structure) - 3:
@@ -130,8 +127,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None,
- add_layer_norm=False, use_dropout=False):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
self.filename = None
self.name = name
self.layers = {}
@@ -146,10 +142,8 @@ class Hypernetwork:
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
)
def weights(self):
@@ -196,15 +190,13 @@ class Hypernetwork:
self.add_layer_norm = state_dict.get('is_layer_norm', False)
print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False)
- print(f"Dropout usage is set to {self.use_dropout}")
+ print(f"Dropout usage is set to {self.use_dropout}" )
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.use_dropout),
- HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
)
self.name = state_dict.get('name', self.name)
@@ -316,7 +308,7 @@ def statistics(data):
std = 0
else:
std = stdev(data)
- total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})"
+ total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
recent_data = data[-32:]
if len(recent_data) < 2:
std = 0
@@ -326,7 +318,7 @@ def statistics(data):
return total_information, recent_information
-def report_statistics(loss_info: dict):
+def report_statistics(loss_info:dict):
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
for key in keys:
try:
@@ -338,18 +330,14 @@ def report_statistics(loss_info: dict):
print(e)
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
- training_height, steps, create_image_every, save_hypernetwork_every, template_file,
- preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
- preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
- textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps,
- save_hypernetwork_every, create_image_every, log_directory,
- name="hypernetwork")
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
@@ -384,29 +372,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-
+
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
- height=training_height,
- repeats=shared.opts.training_image_repeats_per_epoch,
- placeholder_token=hypernetwork_name,
- model=shared.sd_model, device=devices.device,
- template_file=template_file, include_cond=True,
- batch_size=batch_size)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
size = len(ds.indexes)
- loss_dict = defaultdict(lambda: deque(maxlen=1024))
+ loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
previous_mean_losses = [0]
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
-
+
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
@@ -425,7 +407,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if len(loss_dict) > 0:
previous_mean_losses = [i[-1] for i in loss_dict.values()]
previous_mean_loss = mean(previous_mean_losses)
-
+
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
@@ -444,7 +426,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item()
for entry in entries:
loss_dict[entry.filename].append(loss.item())
-
+
optimizer.zero_grad()
weights[0].grad = None
loss.backward()
@@ -459,9 +441,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
steps_done = hypernetwork.step + 1
- if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
-
+
if len(previous_mean_losses) > 1:
std = stdev(previous_mean_losses)
else:
@@ -510,7 +492,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
preview_text = p.prompt
processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images) > 0 else None
+ image = processed.images[0] if len(processed.images)>0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -518,10 +500,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None:
shared.state.current_image = image
- last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
- shared.opts.samples_format, processed.infotexts[0],
- p=p, forced_filename=forced_filename,
- save_to_dirs=False)
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
@@ -535,7 +514,7 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
-
+
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
@@ -543,7 +522,6 @@ Last saved image: {html.escape(last_saved_image)}
return hypernetwork, filename
-
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
old_hypernetwork_name = hypernetwork.name
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
@@ -557,4 +535,4 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
hypernetwork.sd_checkpoint = old_sd_checkpoint
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
hypernetwork.name = old_hypernetwork_name
- raise
\ No newline at end of file
+ raise
--
cgit v1.2.3
From 283249d2390f0f3a1c8a55d5d9aa551e3e9b2f9c Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Fri, 4 Nov 2022 15:57:17 +0900
Subject: apply
---
modules/hypernetworks/hypernetwork.py | 54 +++++++++++++++++++++++++++++++----
1 file changed, 49 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 6e1a10cf..de8688a9 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -22,6 +22,8 @@ from collections import defaultdict, deque
from statistics import stdev, mean
+optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -142,6 +144,8 @@ class Hypernetwork:
self.use_dropout = use_dropout
self.activate_output = activate_output
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
+ self.optimizer_name = None
+ self.optimizer_state_dict = None
for size in enable_sizes or []:
self.layers[size] = (
@@ -163,6 +167,7 @@ class Hypernetwork:
def save(self, filename):
state_dict = {}
+ optimizer_saved_dict = {}
for k, v in self.layers.items():
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -178,8 +183,15 @@ class Hypernetwork:
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
state_dict['activate_output'] = self.activate_output
state_dict['last_layer_dropout'] = self.last_layer_dropout
-
+
+ if self.optimizer_name is not None:
+ optimizer_saved_dict['optimizer_name'] = self.optimizer_name
+
torch.save(state_dict, filename)
+ if self.optimizer_state_dict:
+ optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
+ optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
+ torch.save(optimizer_saved_dict, filename + '.optim')
def load(self, filename):
self.filename = filename
@@ -202,6 +214,18 @@ class Hypernetwork:
print(f"Activate last layer is set to {self.activate_output}")
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
+ optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
+ self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
+ print(f"Optimizer name is {self.optimizer_name}")
+ if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+ self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+ else:
+ self.optimizer_state_dict = None
+ if self.optimizer_state_dict:
+ print("Loaded existing optimizer from checkpoint")
+ else:
+ print("No saved optimizer exists in checkpoint")
+
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
@@ -223,7 +247,7 @@ def list_hypernetworks(path):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
- res[name] = filename
+ res[name + f"({sd_models.model_hash(filename)})"] = filename
return res
@@ -369,6 +393,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
else:
hypernetwork_dir = None
+ hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
@@ -404,8 +429,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
- # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
- optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+ # Use the optimizer type saved with the HN; this could also be specified as a UI option.
+ if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
+ optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+ else:
+ print(f"Optimizer type {optimizer_name} is not defined!")
+ optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+ optimizer_name = 'AdamW'
+ if hypernetwork.optimizer_state_dict: # This must be revised if the optimizer type can differ from the one saved.
+ try:
+ optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+ except RuntimeError as e:
+ print("Cannot resume from saved optimizer!")
+ print(e)
steps_without_grad = 0
@@ -467,7 +503,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+ hypernetwork.optimizer_name = optimizer_name
+ if shared.opts.save_optimizer_state:
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}",
@@ -530,8 +570,12 @@ Last saved image: {html.escape(last_saved_image)}
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+ hypernetwork.optimizer_name = optimizer_name
+ if shared.opts.save_optimizer_state:
+ hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
-
+ del optimizer
+ hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
return hypernetwork, filename
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
--
cgit v1.2.3
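The patch above pairs each hypernetwork checkpoint with a sidecar `.optim` file, keyed to the checkpoint's hash so stale optimizer state is never applied to mismatched weights. A minimal sketch of the same save/resume pattern in plain PyTorch (the function names here are illustrative, not from the patch):

```
import os
import torch

def save_with_optimizer(model, optimizer, filename):
    # Weights go to the checkpoint; optimizer state goes to a sidecar file.
    torch.save(model.state_dict(), filename)
    torch.save({'optimizer_state_dict': optimizer.state_dict()}, filename + '.optim')

def try_resume_optimizer(optimizer, filename):
    # Restore optimizer state only if a sidecar exists and loads cleanly.
    optim_path = filename + '.optim'
    if not os.path.exists(optim_path):
        print("No saved optimizer exists in checkpoint")
        return False
    saved = torch.load(optim_path, map_location='cpu')
    try:
        optimizer.load_state_dict(saved['optimizer_state_dict'])
        print("Loaded existing optimizer from checkpoint")
        return True
    except RuntimeError as e:
        print("Cannot resume from saved optimizer!")
        print(e)
        return False
```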
From f5d394214d6ee74a682d0a1016bcbebc4b43c13a Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Fri, 4 Nov 2022 16:04:03 +0900
Subject: split before declaring file name
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index de8688a9..9b6a3e62 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -382,6 +382,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
+ hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
@@ -393,7 +394,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
else:
hypernetwork_dir = None
- hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
--
cgit v1.2.3
From 1ca0bcd3a7003dd2c1324de7d97fd2a6fc5ddc53 Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Fri, 4 Nov 2022 16:09:19 +0900
Subject: only save if option is enabled
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 9b6a3e62..b1f308e2 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -188,7 +188,7 @@ class Hypernetwork:
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
torch.save(state_dict, filename)
- if self.optimizer_state_dict:
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict:
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(optimizer_saved_dict, filename + '.optim')
--
cgit v1.2.3
From ccf1a15412ef6b518f9f54cc26a0ee5edf458108 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 4 Nov 2022 10:16:19 +0300
Subject: add an option to enable installing extensions with --listen or
--share
---
modules/shared.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 024c771a..0a39cdf2 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -44,6 +44,7 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
+parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
@@ -99,7 +100,7 @@ restricted_opts = {
"outdir_save",
}
-cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen
+cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
--
cgit v1.2.3
From 321e13ca176b256177c4a752d1f2bbee79b5532e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 4 Nov 2022 10:35:30 +0300
Subject: produce a readable error message when setting an option fails on the
settings screen
---
modules/ui.py | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 633b56ef..3ac7540c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1439,8 +1439,7 @@ def create_ui(wrap_gradio_gpu_call):
changed = 0
for key, value, comp in zip(opts.data_labels.keys(), args, components):
- if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
- return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
+ assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp == dummy_component:
@@ -1458,7 +1457,7 @@ def create_ui(wrap_gradio_gpu_call):
opts.save(shared.config_filename)
- return f'{changed} settings changed.', opts.dumpjson()
+ return opts.dumpjson(), f'{changed} settings changed.'
def run_settings_single(value, key):
if not opts.same_type(value, opts.data_labels[key].default):
@@ -1622,9 +1621,9 @@ def create_ui(wrap_gradio_gpu_call):
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
- fn=run_settings,
+ fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
inputs=components,
- outputs=[result, text_settings],
+ outputs=[text_settings, result],
)
for i, k, item in quicksettings_list:
--
cgit v1.2.3
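The rewrite above replaces an early `return` with an `assert`, then relies on the wrapper to turn the raised `AssertionError` into readable UI output; `extra_outputs=[gr.update()]` supplies a placeholder for the settings JSON when the call fails. A rough sketch of that wrapper pattern, under a simplified signature (the real `wrap_gradio_call` in modules/ui.py also handles timing and memory stats):

```
def wrap_call(func, extra_outputs=None):
    # Turn exceptions from the wrapped function into a readable message,
    # padding the remaining outputs with the supplied placeholders.
    def wrapped(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            return tuple(extra_outputs or []) + (f"Error: {e}",)
    return wrapped
```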
From f674c488d9701e577e2aaf25e331fb44ada4f1ef Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 4 Nov 2022 10:45:34 +0300
Subject: bugfix: save image for hires fix BEFORE upscaling latent space
---
modules/processing.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index a46e592d..7a2fc218 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -665,17 +665,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
if opts.use_scale_latent_for_hires_fix:
+ for i in range(samples.shape[0]):
+ save_intermediate(samples, i)
+
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
-
+
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
else:
image_conditioning = self.txt2img_image_conditioning(samples)
-
- for i in range(samples.shape[0]):
- save_intermediate(samples, i)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
--
cgit v1.2.3
From 7278897982bfb640ee95f144c97ed25fb3f77ea3 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Fri, 4 Nov 2022 17:12:28 +0900
Subject: Update shared.py
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 4d6e1c8b..6e7a02e0 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -309,7 +309,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
- "save_optimizer_state": OptionInfo(False, "Saves Optimizer state with checkpoints. This will cause file size to increase VERY much."),
+ "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
--
cgit v1.2.3
From 99043f33606d3057f83ea52a403e10cd29d1f7e7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 4 Nov 2022 11:20:42 +0300
Subject: fix one of previous merges breaking the program
---
modules/sd_models.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 63e07a12..34c57bfa 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -167,6 +167,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
sd_vae.restore_base_vae(model)
checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
+ vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+
if checkpoint_info not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
--
cgit v1.2.3
From eeb07330131012c0294afb79165b90270679b9c7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 4 Nov 2022 11:21:40 +0300
Subject: change the process_one virtual function for scripts to process_batch; add
extra args and docs
---
modules/processing.py | 2 +-
modules/scripts.py | 16 +++++++++++-----
2 files changed, 12 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index e20d8fc4..03c9143d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -502,7 +502,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
break
if p.scripts is not None:
- p.scripts.process_one(p, n)
+ p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
diff --git a/modules/scripts.py b/modules/scripts.py
index 75e47cd2..366c90d7 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -73,9 +73,15 @@ class Script:
pass
- def process_one(self, p, n, *args):
+ def process_batch(self, p, *args, **kwargs):
"""
- Same as process(), but called for every iteration
+ Same as process(), but called for every batch.
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
"""
pass
@@ -303,13 +309,13 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- def process_one(self, p, n):
+ def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
- script.process_one(p, n, *script_args)
+ script.process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running process_one: {script.filename}", file=sys.stderr)
+ print(f"Error running process_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def postprocess(self, p, processed):
--
cgit v1.2.3
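For extension authors, a minimal always-on script using the new hook could look like the sketch below (it assumes `scripts.AlwaysVisible` from modules/scripts.py; the logging itself is purely illustrative):

```
import modules.scripts as scripts

class BatchLogger(scripts.Script):
    def title(self):
        return "Batch logger (example)"

    def show(self, is_img2img):
        # Always-on scripts run for every generation without being selected.
        return scripts.AlwaysVisible

    def process_batch(self, p, *args, **kwargs):
        # kwargs carries batch_number, prompts, seeds and subseeds, per the docs above.
        print(f"batch {kwargs['batch_number']}: seeds={kwargs['seeds']}")
        # Prompts may be edited in place, but the number of entries must not change.
        kwargs['prompts'][:] = [prompt.strip() for prompt in kwargs['prompts']]
```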
From 39541d7725bc42f456a604b07c50aba503a5a09a Mon Sep 17 00:00:00 2001
From: Fampai <>
Date: Fri, 4 Nov 2022 04:50:22 -0400
Subject: Fixes race condition in training when VAE is unloaded
set_current_image can attempt to use the VAE while it has been unloaded to
the CPU during training
---
modules/hypernetworks/hypernetwork.py | 4 ++++
modules/textual_inversion/textual_inversion.py | 5 +++++
2 files changed, 9 insertions(+)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 6e1a10cf..fcb96059 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -390,7 +390,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
+
if unload:
+ shared.parallel_processing_allowed = False
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
@@ -531,6 +534,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
return hypernetwork, filename
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 0aeb0459..55892c57 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -273,7 +273,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
+
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
+
if unload:
+ shared.parallel_processing_allowed = False
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True
@@ -410,6 +414,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
return embedding, filename
--
cgit v1.2.3
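The fix is a save/disable/restore dance around `shared.parallel_processing_allowed`, the flag that keeps `set_current_image` from decoding previews with the now CPU-resident VAE. The same idea as a generic, hypothetical helper (not part of the patch):

```
import contextlib

@contextlib.contextmanager
def flag_disabled(obj, name):
    # Temporarily force a boolean attribute to False, restoring it on exit
    # even if the body raises.
    old = getattr(obj, name)
    setattr(obj, name, False)
    try:
        yield
    finally:
        setattr(obj, name, old)

# usage sketch: with flag_disabled(shared, 'parallel_processing_allowed'): train()
```

Unlike the patch, which restores the flag inline at the end of training, a context manager would also restore it if training raises.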
From 821e2b883dbb42a187bc37379175cd55b7cd7e81 Mon Sep 17 00:00:00 2001
From: TinkTheBoush
Date: Fri, 4 Nov 2022 19:39:03 +0900
Subject: move the shuffle_tags option to the Training settings section
---
modules/hypernetworks/hypernetwork.py | 4 ++--
modules/shared.py | 1 +
modules/textual_inversion/dataset.py | 5 ++---
modules/textual_inversion/textual_inversion.py | 4 ++--
4 files changed, 7 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7630fb81..a11e01d6 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -331,7 +331,7 @@ def report_statistics(loss_info:dict):
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@@ -376,7 +376,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
diff --git a/modules/shared.py b/modules/shared.py
index 1ccb269a..e1d9bdf1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -290,6 +290,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
+ "shuffle_tags": OptionInfo(False, "Shuffleing tags by "," when create texts."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index e9d97cc1..df278dc2 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -24,7 +24,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", shuffle_tags=True, model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
@@ -33,7 +33,6 @@ class PersonalizedBase(Dataset):
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
- self.shuffle_tags = shuffle_tags
self.dataset = []
@@ -99,7 +98,7 @@ class PersonalizedBase(Dataset):
def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
- if self.tag_shuffle:
+ if shared.opts.shuffle_tags:
tags = filename_text.split(',')
random.shuffle(tags)
text = text.replace("[filewords]", ','.join(tags))
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 82dde931..0aeb0459 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -224,7 +224,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
@@ -272,7 +272,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
--
cgit v1.2.3
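The relocated option changes where the shuffle decision is read (the global `shared.opts.shuffle_tags` instead of a per-run argument), but the underlying behaviour in `create_text` stays the same. Isolated below, with an illustrative helper name:

```
import random

def build_text(template_line, filename_text, shuffle=False):
    # Mirror of the create_text logic above: optionally shuffle the
    # comma-separated filename tags before substituting [filewords].
    tags = filename_text.split(',')
    if shuffle:
        random.shuffle(tags)
    return template_line.replace("[filewords]", ','.join(tags))

print(build_text("a photo of [filewords]", "1girl, smile, outdoors", shuffle=True))
```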
From 45b65e87e0ef64b3e457f7d20c62d591cdcd0e7b Mon Sep 17 00:00:00 2001
From: TinkTheBoush
Date: Fri, 4 Nov 2022 19:48:28 +0900
Subject: remove ui option
---
modules/ui.py | 3 ---
1 file changed, 3 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 6f3836c6..45cd8c3f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1269,7 +1269,6 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
- shuffle_tags = gr.Checkbox(label='Shuffleing tags by "," when create texts', value=True)
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
@@ -1364,7 +1363,6 @@ def create_ui(wrap_gradio_gpu_call):
template_file,
save_image_with_stored_embedding,
preview_from_txt2img,
- shuffle_tags,
*txt2img_preview_params,
],
outputs=[
@@ -1389,7 +1387,6 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every,
template_file,
preview_from_txt2img,
- shuffle_tags,
*txt2img_preview_params,
],
outputs=[
--
cgit v1.2.3
From fd62727893f9face287b0a9620251afaa38a627d Mon Sep 17 00:00:00 2001
From: Isaac Poulton
Date: Fri, 4 Nov 2022 18:34:35 +0700
Subject: Sort hypernetworks
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 6e1a10cf..f1f04a70 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -224,7 +224,7 @@ def list_hypernetworks(path):
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
res[name] = filename
- return res
+ return dict(sorted(res.items()))
def load_hypernetwork(filename):
--
cgit v1.2.3
From c3cd0d7a86f35a5bfc58fdc3ecfaf203c0aee06f Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Fri, 4 Nov 2022 12:19:16 +0000
Subject: Should be one underscore for module privates, not two
---
modules/script_callbacks.py | 37 ++++++++++++++++++-------------------
1 file changed, 18 insertions(+), 19 deletions(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 4a7fb944..83da7ca4 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -46,7 +46,7 @@ class CFGDenoiserParams:
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
-__callback_map = dict(
+_callback_map = dict(
callbacks_app_started=[],
callbacks_model_loaded=[],
callbacks_ui_tabs=[],
@@ -58,11 +58,11 @@ __callback_map = dict(
def clear_callbacks():
- for callback_list in __callback_map.values():
+ for callback_list in _callback_map.values():
callback_list.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
- for c in __callback_map['callbacks_app_started']:
+ for c in _callback_map['callbacks_app_started']:
try:
c.callback(demo, app)
except Exception:
@@ -70,7 +70,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
def model_loaded_callback(sd_model):
- for c in __callback_map['callbacks_model_loaded']:
+ for c in _callback_map['callbacks_model_loaded']:
try:
c.callback(sd_model)
except Exception:
@@ -80,7 +80,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
- for c in __callback_map['callbacks_ui_tabs']:
+ for c in _callback_map['callbacks_ui_tabs']:
try:
res += c.callback() or []
except Exception:
@@ -90,7 +90,7 @@ def ui_tabs_callback():
def ui_settings_callback():
- for c in __callback_map['callbacks_ui_settings']:
+ for c in _callback_map['callbacks_ui_settings']:
try:
c.callback()
except Exception:
@@ -98,7 +98,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
- for c in __callback_map['callbacks_before_image_saved']:
+ for c in _callback_map['callbacks_before_image_saved']:
try:
c.callback(params)
except Exception:
@@ -106,7 +106,7 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams):
- for c in __callback_map['callbacks_image_saved']:
+ for c in _callback_map['callbacks_image_saved']:
try:
c.callback(params)
except Exception:
@@ -114,7 +114,7 @@ def image_saved_callback(params: ImageSaveParams):
def cfg_denoiser_callback(params: CFGDenoiserParams):
- for c in __callback_map['callbacks_cfg_denoiser']:
+ for c in _callback_map['callbacks_cfg_denoiser']:
try:
c.callback(params)
except Exception:
@@ -133,13 +133,13 @@ def remove_current_script_callbacks():
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
if filename == 'unknown file':
return
- for callback_list in __callback_map.values():
+ for callback_list in _callback_map.values():
for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
callback_list.remove(callback_to_remove)
def remove_callbacks_for_function(callback_func):
- for callback_list in __callback_map.values():
+ for callback_list in _callback_map.values():
for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
callback_list.remove(callback_to_remove)
@@ -147,13 +147,13 @@ def remove_callbacks_for_function(callback_func):
def on_app_started(callback):
"""register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments"""
- add_callback(__callback_map['callbacks_app_started'], callback)
+ add_callback(_callback_map['callbacks_app_started'], callback)
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
- add_callback(__callback_map['callbacks_model_loaded'], callback)
+ add_callback(_callback_map['callbacks_model_loaded'], callback)
def on_ui_tabs(callback):
@@ -166,13 +166,13 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
- add_callback(__callback_map['callbacks_ui_tabs'], callback)
+ add_callback(_callback_map['callbacks_ui_tabs'], callback)
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
- add_callback(__callback_map['callbacks_ui_settings'], callback)
+ add_callback(_callback_map['callbacks_ui_settings'], callback)
def on_before_image_saved(callback):
@@ -180,7 +180,7 @@ def on_before_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
"""
- add_callback(__callback_map['callbacks_before_image_saved'], callback)
+ add_callback(_callback_map['callbacks_before_image_saved'], callback)
def on_image_saved(callback):
@@ -188,7 +188,7 @@ def on_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
- add_callback(__callback_map['callbacks_image_saved'], callback)
+ add_callback(_callback_map['callbacks_image_saved'], callback)
def on_cfg_denoiser(callback):
@@ -196,5 +196,4 @@ def on_cfg_denoiser(callback):
The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
- add_callback(__callback_map['callbacks_cfg_denoiser'], callback)
-
+ add_callback(_callback_map['callbacks_cfg_denoiser'], callback)
--
cgit v1.2.3
From f316280ad3634a2343b086a6de0bfcd473e18599 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 4 Nov 2022 16:48:40 +0300
Subject: fix the error that prevents from setting some options
---
modules/shared.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index a9e28b9c..962115f6 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -406,7 +406,8 @@ class Options:
if key in self.data or key in self.data_labels:
assert not cmd_opts.freeze_settings, "changing settings is disabled"
- comp_args = opts.data_labels[key].component_args
+ info = opts.data_labels.get(key, None)
+ comp_args = info.component_args if info else None
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
raise RuntimeError(f"not possible to set {key} because it is restricted")
--
cgit v1.2.3
From 116bcf730ade8d3ac5d76d04c5887b6bba000970 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 4 Nov 2022 16:48:46 +0300
Subject: disable setting options via API until it is fixed by the author
---
modules/api/api.py | 4 ++++
1 file changed, 4 insertions(+)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index a49f3755..8a7ab2f5 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -218,6 +218,10 @@ class Api:
return options
def set_config(self, req: OptionsModel):
+ # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
+ # overwrite all options with default values.
+ raise RuntimeError('Setting options via API is not supported')
+
reqDict = vars(req)
for o in reqDict:
setattr(shared.opts, o, reqDict[o])
--
cgit v1.2.3
From 08feb4c364e8b2aed929fd7d22dfa21a93d78b2c Mon Sep 17 00:00:00 2001
From: Isaac Poulton
Date: Fri, 4 Nov 2022 20:53:11 +0700
Subject: Sort straight out of the glob
---
modules/hypernetworks/hypernetwork.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index f1f04a70..a441ab10 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -219,12 +219,12 @@ class Hypernetwork:
def list_hypernetworks(path):
res = {}
- for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+ for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
res[name] = filename
- return dict(sorted(res.items()))
+ return res
def load_hypernetwork(filename):
--
cgit v1.2.3
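Both versions yield the same ordering; sorting at the glob just avoids building the dict and re-sorting it afterwards. The explicit `sorted(...)` is what matters, because `glob.iglob` yields paths in arbitrary, OS-dependent order:

```
import glob
import os

# sorted() materialises the generator and orders paths lexicographically,
# giving a deterministic hypernetwork listing regardless of filesystem order.
path = 'models/hypernetworks'
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
    print(os.path.splitext(os.path.basename(filename))[0])
```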
From 5844ef8a9a165e0f456a4658bda830282cf5a55e Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Fri, 4 Nov 2022 16:02:25 +0000
Subject: remove private underscore indicator
---
modules/script_callbacks.py | 36 ++++++++++++++++++------------------
1 file changed, 18 insertions(+), 18 deletions(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 83da7ca4..74dfb880 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -46,7 +46,7 @@ class CFGDenoiserParams:
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
-_callback_map = dict(
+callback_map = dict(
callbacks_app_started=[],
callbacks_model_loaded=[],
callbacks_ui_tabs=[],
@@ -58,11 +58,11 @@ _callback_map = dict(
def clear_callbacks():
- for callback_list in _callback_map.values():
+ for callback_list in callback_map.values():
callback_list.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
- for c in _callback_map['callbacks_app_started']:
+ for c in callback_map['callbacks_app_started']:
try:
c.callback(demo, app)
except Exception:
@@ -70,7 +70,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
def model_loaded_callback(sd_model):
- for c in _callback_map['callbacks_model_loaded']:
+ for c in callback_map['callbacks_model_loaded']:
try:
c.callback(sd_model)
except Exception:
@@ -80,7 +80,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
- for c in _callback_map['callbacks_ui_tabs']:
+ for c in callback_map['callbacks_ui_tabs']:
try:
res += c.callback() or []
except Exception:
@@ -90,7 +90,7 @@ def ui_tabs_callback():
def ui_settings_callback():
- for c in _callback_map['callbacks_ui_settings']:
+ for c in callback_map['callbacks_ui_settings']:
try:
c.callback()
except Exception:
@@ -98,7 +98,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
- for c in _callback_map['callbacks_before_image_saved']:
+ for c in callback_map['callbacks_before_image_saved']:
try:
c.callback(params)
except Exception:
@@ -106,7 +106,7 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams):
- for c in _callback_map['callbacks_image_saved']:
+ for c in callback_map['callbacks_image_saved']:
try:
c.callback(params)
except Exception:
@@ -114,7 +114,7 @@ def image_saved_callback(params: ImageSaveParams):
def cfg_denoiser_callback(params: CFGDenoiserParams):
- for c in _callback_map['callbacks_cfg_denoiser']:
+ for c in callback_map['callbacks_cfg_denoiser']:
try:
c.callback(params)
except Exception:
@@ -133,13 +133,13 @@ def remove_current_script_callbacks():
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
if filename == 'unknown file':
return
- for callback_list in _callback_map.values():
+ for callback_list in callback_map.values():
for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
callback_list.remove(callback_to_remove)
def remove_callbacks_for_function(callback_func):
- for callback_list in _callback_map.values():
+ for callback_list in callback_map.values():
for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
callback_list.remove(callback_to_remove)
@@ -147,13 +147,13 @@ def remove_callbacks_for_function(callback_func):
def on_app_started(callback):
"""register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments"""
- add_callback(_callback_map['callbacks_app_started'], callback)
+ add_callback(callback_map['callbacks_app_started'], callback)
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
- add_callback(_callback_map['callbacks_model_loaded'], callback)
+ add_callback(callback_map['callbacks_model_loaded'], callback)
def on_ui_tabs(callback):
@@ -166,13 +166,13 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
- add_callback(_callback_map['callbacks_ui_tabs'], callback)
+ add_callback(callback_map['callbacks_ui_tabs'], callback)
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
- add_callback(_callback_map['callbacks_ui_settings'], callback)
+ add_callback(callback_map['callbacks_ui_settings'], callback)
def on_before_image_saved(callback):
@@ -180,7 +180,7 @@ def on_before_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
"""
- add_callback(_callback_map['callbacks_before_image_saved'], callback)
+ add_callback(callback_map['callbacks_before_image_saved'], callback)
def on_image_saved(callback):
@@ -188,7 +188,7 @@ def on_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
- add_callback(_callback_map['callbacks_image_saved'], callback)
+ add_callback(callback_map['callbacks_image_saved'], callback)
def on_cfg_denoiser(callback):
@@ -196,4 +196,4 @@ def on_cfg_denoiser(callback):
The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
- add_callback(_callback_map['callbacks_cfg_denoiser'], callback)
+ add_callback(callback_map['callbacks_cfg_denoiser'], callback)
--
cgit v1.2.3
From 0d7e01d9950e013784c4b77c05aa7583ea69edc8 Mon Sep 17 00:00:00 2001
From: innovaciones
Date: Fri, 4 Nov 2022 12:14:32 -0600
Subject: Open extension links in a new tab
Fixed for the "Available" tab
---
modules/ui_extensions.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index a81de9a7..8e0d41d5 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -188,7 +188,7 @@ def refresh_available_extensions_from_data():
code += f"""
- <td><a href="{html.escape(url)}">{html.escape(name)}</a></td>
+ <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a></td>
<td>{html.escape(description)}</td>
<td>{install_code}</td>
--
cgit v1.2.3
From b8435e632f7ba0da12a2c8e9c788dda519279d24 Mon Sep 17 00:00:00 2001
From: evshiron
Date: Sat, 5 Nov 2022 02:36:47 +0800
Subject: add --cors-allow-origins cmd opt
---
modules/shared.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index a9e28b9c..e83cbcdf 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -86,6 +86,7 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
+parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None)
cmd_opts = parser.parse_args()
restricted_opts = {
@@ -147,9 +148,9 @@ class State:
self.interrupted = True
def nextjob(self):
- if opts.show_progress_every_n_steps == -1:
+ if opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
-
+
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
@@ -198,7 +199,7 @@ class State:
return
if self.current_latent is None:
return
-
+
if opts.show_progress_grid:
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
else:
--
cgit v1.2.3
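The flag itself only lands in `cmd_opts`; applying it happens where the FastAPI app is created (webui.py, outside this patch). A sketch of how a comma-separated origin list would typically be wired up, using Starlette's standard CORSMiddleware:

```
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

def setup_cors(app: FastAPI, cors_allow_origins: str):
    # Split the comma-separated --cors-allow-origins value into a list
    # and register it with the CORS middleware.
    if cors_allow_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=cors_allow_origins.split(','),
            allow_methods=['*'],
        )
```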
From 467d8b967b5d1b1984ab113bec3fff217736e7ac Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sat, 5 Nov 2022 04:24:42 +0900
Subject: Fix errors from commit f2b697 with --hide-ui-dir-config
https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/f2b69709eaff88fc3a2bd49585556ec0883bf5ea
---
modules/ui.py | 14 ++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 4c2829af..76ca9b07 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1446,17 +1446,19 @@ def create_ui(wrap_gradio_gpu_call):
continue
oldval = opts.data.get(key, None)
-
- setattr(opts, key, value)
-
+ try:
+ setattr(opts, key, value)
+ except RuntimeError:
+ continue
if oldval != value:
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
changed += 1
-
- opts.save(shared.config_filename)
-
+ try:
+ opts.save(shared.config_filename)
+ except RuntimeError:
+ return opts.dumpjson(), f'{changed} settings changed without save.'
return opts.dumpjson(), f'{changed} settings changed.'
def run_settings_single(value, key):
--
cgit v1.2.3
From 30b1bcc64e67ad50c5d3af3a6fe1bd1e9553f34e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 4 Nov 2022 22:56:18 +0300
Subject: fix upscale loop erroneously applied multiple times
---
modules/upscaler.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 83fde7ca..c4e6e6bd 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -57,10 +57,18 @@ class Upscaler:
self.scale = scale
dest_w = img.width * scale
dest_h = img.height * scale
+
for i in range(3):
- if img.width > dest_w and img.height > dest_h:
- break
+ shape = (img.width, img.height)
+
img = self.do_upscale(img, selected_model)
+
+ if shape == (img.width, img.height):
+ break
+
+ if img.width >= dest_w and img.height >= dest_h:
+ break
+
if img.width != dest_w or img.height != dest_h:
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
--
cgit v1.2.3
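The reworked loop has two independent exit conditions: stop when an upscaler pass makes no progress (fixed-scale models can refuse to grow further), and stop, now with `>=` instead of the old strict `>`, once the target size is reached; the strict comparison is what previously let extra passes run. A condensed restatement of the logic:

```
def upscale_until(img, dest_w, dest_h, do_upscale, max_passes=3):
    # Apply the upscaler up to max_passes times, stopping early if a pass
    # makes no progress or the destination size has been reached.
    for _ in range(max_passes):
        before = (img.width, img.height)
        img = do_upscale(img)
        if (img.width, img.height) == before:  # no progress; give up
            break
        if img.width >= dest_w and img.height >= dest_h:  # target reached
            break
    return img
```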
From 6008c0773ea575353f9b87da8a58454e20cc7857 Mon Sep 17 00:00:00 2001
From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com>
Date: Fri, 4 Nov 2022 23:03:05 +0000
Subject: Add support for new DPM-Solver++ samplers
---
modules/sd_samplers.py | 4 ++++
1 file changed, 4 insertions(+)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index c7c414ef..7ece6556 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -29,6 +29,10 @@ samplers_k_diffusion = [
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
+ ('DPM-Solver++(2S) a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM-Solver++(2M)', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
+ ('DPM-Solver++(2S) Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM-Solver++(2M) Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
]
samplers_data_k_diffusion = [
--
cgit v1.2.3
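Each entry in `samplers_k_diffusion` is a 4-tuple: the UI label, the name of the sampling function in k-diffusion, a list of infotext aliases, and an options dict such as `{'scheduler': 'karras'}`. A standalone illustration of the shape (the entry itself is hypothetical):

```
extra_samplers = [
    # (UI label, k_diffusion.sampling function name, infotext aliases, options)
    ('DPM-Solver++(2M) demo', 'sample_dpmpp_2m', ['k_dpmpp_2m_demo'], {'scheduler': 'karras'}),
]

for label, funcname, aliases, options in extra_samplers:
    print(label, funcname, aliases, options.get('scheduler'))
```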
From f92dc505a013af9e385c7edbdf97539be62503d6 Mon Sep 17 00:00:00 2001
From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com>
Date: Fri, 4 Nov 2022 23:12:48 +0000
Subject: Fix name
---
modules/sd_samplers.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 7ece6556..b28a2e4c 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -31,7 +31,7 @@ samplers_k_diffusion = [
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
('DPM-Solver++(2S) a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
('DPM-Solver++(2M)', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
- ('DPM-Solver++(2S) Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM-Solver++(2S) a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
('DPM-Solver++(2M) Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
]
--
cgit v1.2.3
From 1b6c2fc749e12f12bbee4705e65f217d23fa9072 Mon Sep 17 00:00:00 2001
From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com>
Date: Fri, 4 Nov 2022 23:28:13 +0000
Subject: Reorder samplers
---
modules/sd_samplers.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index b28a2e4c..1e88f7ee 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -24,13 +24,13 @@ samplers_k_diffusion = [
('Heun', 'sample_heun', ['k_heun'], {}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
+ ('DPM-Solver++(2S) a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM-Solver++(2M)', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
- ('DPM-Solver++(2S) a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
- ('DPM-Solver++(2M)', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM-Solver++(2S) a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
('DPM-Solver++(2M) Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
]
--
cgit v1.2.3
From ebce0c57c78a3f22178e3a38938d19ec0dfb703d Mon Sep 17 00:00:00 2001
From: Billy Cao
Date: Sat, 5 Nov 2022 11:38:24 +0800
Subject: Use typing.Optional instead of | to add support for Python 3.9 and
below.
---
modules/api/models.py | 26 +++++++++++++-------------
1 file changed, 13 insertions(+), 13 deletions(-)
(limited to 'modules')
diff --git a/modules/api/models.py b/modules/api/models.py
index 2ae75f43..a44c5ddd 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,6 +1,6 @@
import inspect
from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional, Union
+from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
@@ -185,22 +185,22 @@ _options = vars(parser)['_option_string_actions']
for key in _options:
if(_options[key].dest != 'help'):
flag = _options[key]
- _type = str
- if(_options[key].default != None): _type = type(_options[key].default)
+ _type = str
+ if _options[key].default is not None: _type = type(_options[key].default)
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
FlagsModel = create_model("Flags", **flags)
class SamplerItem(BaseModel):
name: str = Field(title="Name")
- aliases: list[str] = Field(title="Aliases")
+ aliases: list[str] = Field(title="Aliases")
options: dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel):
name: str = Field(title="Name")
- model_name: str | None = Field(title="Model Name")
- model_path: str | None = Field(title="Path")
- model_url: str | None = Field(title="URL")
+ model_name: Optional[str] = Field(title="Model Name")
+ model_path: Optional[str] = Field(title="Path")
+ model_url: Optional[str] = Field(title="URL")
class SDModelItem(BaseModel):
title: str = Field(title="Title")
@@ -211,21 +211,21 @@ class SDModelItem(BaseModel):
class HypernetworkItem(BaseModel):
name: str = Field(title="Name")
- path: str | None = Field(title="Path")
+ path: Optional[str] = Field(title="Path")
class FaceRestorerItem(BaseModel):
name: str = Field(title="Name")
- cmd_dir: str | None = Field(title="Path")
+ cmd_dir: Optional[str] = Field(title="Path")
class RealesrganItem(BaseModel):
name: str = Field(title="Name")
- path: str | None = Field(title="Path")
- scale: int | None = Field(title="Scale")
+ path: Optional[str] = Field(title="Path")
+ scale: Optional[int] = Field(title="Scale")
class PromptStyleItem(BaseModel):
name: str = Field(title="Name")
- prompt: str | None = Field(title="Prompt")
- negative_prompt: str | None = Field(title="Negative Prompt")
+ prompt: Optional[str] = Field(title="Prompt")
+ negative_prompt: Optional[str] = Field(title="Negative Prompt")
class ArtistItem(BaseModel):
name: str = Field(title="Name")
--
cgit v1.2.3
From e9a5562b9b27a1a4f9c282637b111cefd9727a41 Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Sat, 5 Nov 2022 04:06:51 -0500
Subject: add support for tls (gradio tls options)
---
modules/shared.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 962115f6..7a20c3af 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -86,6 +86,9 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
+parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
+parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
+parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
cmd_opts = parser.parse_args()
restricted_opts = {
--
cgit v1.2.3
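These three flags map onto parameters of gradio's `launch()` (the actual launch call lives in webui.py, outside this patch). A self-contained sketch of how they would plausibly be passed through:

```
import argparse
import gradio as gr

parser = argparse.ArgumentParser()
parser.add_argument("--tls-keyfile", type=str, default=None)
parser.add_argument("--tls-certfile", type=str, default=None)
parser.add_argument("--server-name", type=str, default=None)
cmd_opts = parser.parse_args()

demo = gr.Interface(fn=lambda x: x, inputs="text", outputs="text")
# gradio serves HTTPS only when both ssl_keyfile and ssl_certfile are set,
# matching the "partially enables TLS" wording in the help text above.
demo.launch(
    server_name=cmd_opts.server_name,
    ssl_keyfile=cmd_opts.tls_keyfile,
    ssl_certfile=cmd_opts.tls_certfile,
)
```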
From 03b08c4a6b0609f24ec789d40100529b92ef0612 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 5 Nov 2022 15:04:48 +0300
Subject: do not die when an extension's repo has no remote
---
modules/extensions.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/extensions.py b/modules/extensions.py
index 897af96e..8e0977fd 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -34,8 +34,11 @@ class Extension:
if repo is None or repo.bare:
self.remote = None
else:
- self.remote = next(repo.remote().urls, None)
- self.status = 'unknown'
+ try:
+ self.remote = next(repo.remote().urls, None)
+ self.status = 'unknown'
+ except Exception:
+ self.remote = None
def list_files(self, subdir, extension):
from modules import scripts
--
cgit v1.2.3
From a170e3d22231e145f42bb878a76ae5f76fdca230 Mon Sep 17 00:00:00 2001
From: Evgeniy
Date: Sat, 5 Nov 2022 17:06:56 +0300
Subject: Python 3.8 typing compatibility
Solves problems with
```Traceback (most recent call last):
File "webui.py", line 201, in <module>
webui()
File "webui.py", line 178, in webui
create_api(app)
File "webui.py", line 117, in create_api
from modules.api.api import Api
File "H:\AIart\stable-diffusion\stable-diffusion-webui\modules\api\api.py", line 9, in <module>
from modules.api.models import *
File "H:\AIart\stable-diffusion\stable-diffusion-webui\modules\api\models.py", line 194, in <module>
class SamplerItem(BaseModel):
File "H:\AIart\stable-diffusion\stable-diffusion-webui\modules\api\models.py", line 196, in SamplerItem
aliases: list[str] = Field(title="Aliases")
TypeError: 'type' object is not subscriptable```
and
```Traceback (most recent call last):
File "webui.py", line 201, in <module>
webui()
File "webui.py", line 178, in webui
create_api(app)
File "webui.py", line 117, in create_api
from modules.api.api import Api
File "H:\AIart\stable-diffusion\stable-diffusion-webui\modules\api\api.py", line 9, in <module>
from modules.api.models import *
File "H:\AIart\stable-diffusion\stable-diffusion-webui\modules\api\models.py", line 194, in <module>
class SamplerItem(BaseModel):
File "H:\AIart\stable-diffusion\stable-diffusion-webui\modules\api\models.py", line 197, in SamplerItem
options: dict[str, str] = Field(title="Options")
TypeError: 'type' object is not subscriptable```
---
modules/api/models.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/api/models.py b/modules/api/models.py
index a44c5ddd..f89da1ff 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -5,7 +5,7 @@ from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers, opts, parser
-from typing import List
+from typing import Dict, List
API_NOT_ALLOWED = [
"self",
@@ -193,8 +193,8 @@ FlagsModel = create_model("Flags", **flags)
class SamplerItem(BaseModel):
name: str = Field(title="Name")
- aliases: list[str] = Field(title="Aliases")
- options: dict[str, str] = Field(title="Options")
+ aliases: List[str] = Field(title="Aliases")
+ options: Dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel):
name: str = Field(title="Name")
@@ -230,4 +230,4 @@ class PromptStyleItem(BaseModel):
class ArtistItem(BaseModel):
name: str = Field(title="Name")
score: float = Field(title="Score")
- category: str = Field(title="Category")
\ No newline at end of file
+ category: str = Field(title="Category")
--
cgit v1.2.3
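The underlying rule: subscripting builtin containers (`list[str]`, `dict[str, str]`) only works at runtime from Python 3.9 (PEP 585); on 3.7/3.8 the `typing` aliases are required. A minimal sketch mirroring the fixed model:
```python
from typing import Dict, List
from pydantic import BaseModel, Field

class SamplerItem(BaseModel):
    name: str = Field(title="Name")
    aliases: List[str] = Field(title="Aliases")       # list[str] raises TypeError on 3.8
    options: Dict[str, str] = Field(title="Options")  # dict[str, str] likewise
```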
From 62e3d71aa778928d63cab81d9d8cde33e55bebb3 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 5 Nov 2022 17:09:42 +0300
Subject: rework the code to not use the walrus operator because colab's 3.7
does not support it
---
modules/hypernetworks/hypernetwork.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 5ceed6ee..7f182712 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -429,13 +429,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
+
# Here we use optimizer from saved HN, or we can specify as UI option.
- if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
+ if hypernetwork.optimizer_name in optimizer_dict:
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+ optimizer_name = hypernetwork.optimizer_name
else:
- print(f"Optimizer type {optimizer_name} is not defined!")
+ print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
optimizer_name = 'AdamW'
+
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
try:
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
--
cgit v1.2.3
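For reference, the two forms are equivalent; the walrus operator merely binds the name inside the condition, which Python 3.7 cannot parse. A standalone sketch with stand-in values:
```python
optimizer_dict = {"AdamW": object(), "SGD": object()}  # stand-in for the real mapping
requested = "Lion"                                     # stand-in for hypernetwork.optimizer_name

# Python 3.8+: if (optimizer_name := requested) in optimizer_dict: ...
# Python 3.7-compatible rewrite, as in the commit:
if requested in optimizer_dict:
    optimizer_name = requested
else:
    print(f"Optimizer type {requested} is not defined!")
    optimizer_name = "AdamW"
```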
From 159475e072f2ed3db8235aab9c3fa18640b93b80 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 5 Nov 2022 18:32:22 +0300
Subject: tweak names a bit for new samplers
---
modules/sd_samplers.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 1e88f7ee..783992d2 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -24,15 +24,15 @@ samplers_k_diffusion = [
('Heun', 'sample_heun', ['k_heun'], {}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
- ('DPM-Solver++(2S) a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
- ('DPM-Solver++(2M)', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
- ('DPM-Solver++(2S) a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
- ('DPM-Solver++(2M) Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
+ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
]
samplers_data_k_diffusion = [
--
cgit v1.2.3
From 99b05addb1c98169d78957f13efef308aef0af94 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sat, 5 Nov 2022 18:46:47 -0300
Subject: Fix options endpoint not showing the full list of options
---
modules/api/models.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/api/models.py b/modules/api/models.py
index f89da1ff..0ea62155 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -168,9 +168,9 @@ class ProgressResponse(BaseModel):
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
fields = {}
-for key, value in opts.data.items():
- metadata = opts.data_labels.get(key)
- optType = opts.typemap.get(type(value), type(value))
+for key, metadata in opts.data_labels.items():
+ value = opts.data.get(key)
+ optType = opts.typemap.get(type(metadata.default), type(value))
if (metadata is not None):
fields.update({key: (Optional[optType], Field(
--
cgit v1.2.3
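The bug: opts.data only contains values the user has changed, so iterating it drops every option still at its default. Iterating opts.data_labels (the full registry) and falling back to opts.data covers both. A hypothetical miniature of the two dicts:
```python
data_labels = {"send_seed": True, "grid_format": "png"}  # every option -> its default
data = {"grid_format": "jpg"}                            # only user-modified values

options = {key: data.get(key, default) for key, default in data_labels.items()}
print(options)  # {'send_seed': True, 'grid_format': 'jpg'} -- the full list
```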
From 0ebf66b575f008a027097946eb2f6845feffd010 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sat, 5 Nov 2022 18:58:19 -0300
Subject: Fix set config endpoint
---
modules/api/api.py | 12 ++++--------
1 file changed, 4 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 112000b8..a924c83a 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -230,14 +230,10 @@ class Api:
return options
- def set_config(self, req: OptionsModel):
- # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
- # overwrite all options with default values.
- raise RuntimeError('Setting options via API is not supported')
-
- reqDict = vars(req)
- for o in reqDict:
- setattr(shared.opts, o, reqDict[o])
+ def set_config(self, req: Dict[str, Any]):
+
+ for o in req:
+ setattr(shared.opts, o, req[o])
shared.opts.save(shared.config_filename)
return
--
cgit v1.2.3
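With this change the endpoint accepts a partial dictionary, so a client only sends the keys it wants changed. A hedged usage sketch (assumes a local webui started with --api on the default port):
```python
import requests

resp = requests.post(
    "http://127.0.0.1:7860/sdapi/v1/options",  # adjust host/port to your launch flags
    json={"send_seed": False},                 # partial update; other options are untouched
)
resp.raise_for_status()
```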
From 3c72055c22425dcde0739b5246e3501f4a3ec794 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sat, 5 Nov 2022 19:05:15 -0300
Subject: Add skip endpoint
---
modules/api/api.py | 6 ++++++
1 file changed, 6 insertions(+)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index a924c83a..c7ceb787 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -64,6 +64,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
@@ -219,6 +220,11 @@ class Api:
return {}
+ def skip(self):
+ shared.state.skip()
+
+ return
+
def get_config(self):
options = {}
for key in shared.opts.data.keys():
--
cgit v1.2.3
From 7f63980e479c7ffaec907fb659b5024e96eb72e7 Mon Sep 17 00:00:00 2001
From: Bruno Seoane
Date: Sat, 5 Nov 2022 19:09:13 -0300
Subject: Remove unnecessary return
---
modules/api/api.py | 2 --
1 file changed, 2 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index c7ceb787..33e6c6dc 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -223,8 +223,6 @@ class Api:
def skip(self):
shared.state.skip()
- return
-
def get_config(self):
options = {}
for key in shared.opts.data.keys():
--
cgit v1.2.3
From 6603f63b7b8af39ab815091460c5c2a12d3f253e Mon Sep 17 00:00:00 2001
From: Han Lin
Date: Sun, 6 Nov 2022 11:08:20 +0800
Subject: Fixes LDSR upscaler producing black bars
---
modules/ldsr_model_arch.py | 14 +++++++++++---
1 file changed, 11 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/ldsr_model_arch.py b/modules/ldsr_model_arch.py
index 14db5076..90e0a2f0 100644
--- a/modules/ldsr_model_arch.py
+++ b/modules/ldsr_model_arch.py
@@ -101,8 +101,8 @@ class LDSR:
down_sample_rate = target_scale / 4
wd = width_og * down_sample_rate
hd = height_og * down_sample_rate
- width_downsampled_pre = int(wd)
- height_downsampled_pre = int(hd)
+ width_downsampled_pre = int(np.ceil(wd))
+ height_downsampled_pre = int(np.ceil(hd))
if down_sample_rate != 1:
print(
@@ -110,7 +110,12 @@ class LDSR:
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
- logs = self.run(model["model"], im_og, diffusion_steps, eta)
+
+ # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
+ pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
+ im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
+
+ logs = self.run(model["model"], im_padded, diffusion_steps, eta)
sample = logs["sample"]
sample = sample.detach().cpu()
@@ -120,6 +125,9 @@ class LDSR:
sample = np.transpose(sample, (0, 2, 3, 1))
a = Image.fromarray(sample[0])
+ # remove padding
+ a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
+
del model
gc.collect()
torch.cuda.empty_cache()
--
cgit v1.2.3
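The fix pads the input up to the next multiple of 64 with edge-replicated pixels (so the model never sees hard black borders), then crops the 4x output back to the expected size. A standalone sketch of the pad/crop arithmetic, with a plain resize standing in for the LDSR model:
```python
import numpy as np
from PIL import Image

im = Image.new("RGB", (300, 200), "white")  # stand-in input image

# pad width/height to multiples of 64 using edge values
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im.size) / 64).astype(int)), axis=0) * 64 - im.size
padded = Image.fromarray(np.pad(np.array(im), ((0, pad_h), (0, pad_w), (0, 0)), mode="edge"))

upscaled = padded.resize((padded.width * 4, padded.height * 4))  # stand-in for the model
result = upscaled.crop((0, 0) + tuple(np.array(im.size) * 4))    # drop the padded region
assert result.size == (1200, 800)
```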
From a2a1a2f7270a865175f64475229838a8d64509ea Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 6 Nov 2022 09:02:25 +0300
Subject: add ability to create extensions that add localizations
---
modules/localization.py | 6 ++++++
modules/scripts.py | 1 -
modules/shared.py | 2 --
modules/ui.py | 3 +--
4 files changed, 7 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/localization.py b/modules/localization.py
index b1810cda..f6a6f2fb 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -3,6 +3,7 @@ import os
import sys
import traceback
+
localizations = {}
@@ -16,6 +17,11 @@ def list_localizations(dirname):
localizations[fn] = os.path.join(dirname, file)
+ from modules import scripts
+ for file in scripts.list_scripts("localizations", ".json"):
+ fn, ext = os.path.splitext(file.filename)
+ localizations[fn] = file.path
+
def localization_js(current_localization_name):
fn = localizations.get(current_localization_name, None)
diff --git a/modules/scripts.py b/modules/scripts.py
index 366c90d7..637b2329 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -3,7 +3,6 @@ import sys
import traceback
from collections import namedtuple
-import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
diff --git a/modules/shared.py b/modules/shared.py
index 70b998ff..e8bacd3c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -221,8 +221,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
-localization.list_localizations(cmd_opts.localizations_dir)
-
def realesrgan_models_names():
import modules.realesrgan_model
diff --git a/modules/ui.py b/modules/ui.py
index 76ca9b07..23643c22 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1563,11 +1563,10 @@ def create_ui(wrap_gradio_gpu_call):
shared.state.need_restart = True
restart_gradio.click(
-
fn=request_restart,
+ _js='restart_reload',
inputs=[],
outputs=[],
- _js='restart_reload'
)
if column is not None:
--
cgit v1.2.3
From e5b4e3f820cd09e751f1d168ab05d606d078a0d9 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 6 Nov 2022 10:12:53 +0300
Subject: add tags to extensions, and ability to filter out tags; list changed
settings keys in UI; do not print VRAM/etc stats everywhere but in calls that
use GPU
---
modules/ui.py | 25 ++++++++++++----------
modules/ui_extensions.py | 55 ++++++++++++++++++++++++++++++++++++++----------
2 files changed, 58 insertions(+), 22 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 23643c22..c946ad59 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -174,9 +174,9 @@ def save_pil_to_file(pil_image, dir=None):
gr.processing_utils.save_pil_to_file = save_pil_to_file
-def wrap_gradio_call(func, extra_outputs=None):
+def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
- run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
+ run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
shared.mem_mon.monitor()
t = time.perf_counter()
@@ -203,11 +203,18 @@ def wrap_gradio_call(func, extra_outputs=None):
res = extra_outputs_array + [f"{plaintext_to_html(type(e).__name__+': '+str(e))}"]
+ shared.state.skipped = False
+ shared.state.interrupted = False
+ shared.state.job_count = 0
+
+ if not add_stats:
+ return tuple(res)
+
elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s"
- if (elapsed_m > 0):
+ if elapsed_m > 0:
elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon:
@@ -225,10 +232,6 @@ def wrap_gradio_call(func, extra_outputs=None):
# last item is always HTML
res[-1] += f""
- shared.state.skipped = False
- shared.state.interrupted = False
- shared.state.job_count = 0
-
return tuple(res)
return f
@@ -1436,7 +1439,7 @@ def create_ui(wrap_gradio_gpu_call):
opts.reorder()
def run_settings(*args):
- changed = 0
+ changed = []
for key, value, comp in zip(opts.data_labels.keys(), args, components):
assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
@@ -1454,12 +1457,12 @@ def create_ui(wrap_gradio_gpu_call):
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
- changed += 1
+ changed.append(key)
try:
opts.save(shared.config_filename)
except RuntimeError:
- return opts.dumpjson(), f'{changed} settings changed without save.'
- return opts.dumpjson(), f'{changed} settings changed.'
+ return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
+ return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.'
def run_settings_single(value, key):
if not opts.same_type(value, opts.data_labels[key].default):
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 8e0d41d5..02ab9643 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -140,13 +140,15 @@ def install_extension_from_url(dirname, url):
shutil.rmtree(tmpdir, True)
-def install_extension_from_index(url):
+def install_extension_from_index(url, hide_tags):
ext_table, message = install_extension_from_url(None, url)
- return refresh_available_extensions_from_data(), ext_table, message
+ code, _ = refresh_available_extensions_from_data(hide_tags)
+ return code, ext_table, message
-def refresh_available_extensions(url):
+
+def refresh_available_extensions(url, hide_tags):
global available_extensions
import urllib.request
@@ -155,13 +157,25 @@ def refresh_available_extensions(url):
available_extensions = json.loads(text)
- return url, refresh_available_extensions_from_data(), ''
+ code, tags = refresh_available_extensions_from_data(hide_tags)
+
+ return url, code, gr.CheckboxGroup.update(choices=tags), ''
+
+
+def refresh_available_extensions_for_tags(hide_tags):
+ code, _ = refresh_available_extensions_from_data(hide_tags)
+ return code, ''
-def refresh_available_extensions_from_data():
+
+def refresh_available_extensions_from_data(hide_tags):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
+ tags = available_extensions.get("tags", {})
+ tags_to_hide = set(hide_tags)
+ hidden = 0
+
code = f"""
@@ -178,17 +192,24 @@ def refresh_available_extensions_from_data():
name = ext.get("name", "noname")
url = ext.get("url", None)
description = ext.get("description", "")
+ extension_tags = ext.get("tags", [])
if url is None:
continue
+ if len([x for x in extension_tags if x in tags_to_hide]) > 0:
+ hidden += 1
+ continue
+
existing = installed_extension_urls.get(normalize_git_url(url), None)
install_code = f""" """
+ tags_text = ", ".join([f"{x} " for x in extension_tags])
+
code += f"""
- {html.escape(name)}
+ {html.escape(name)} {tags_text}
{html.escape(description)}
{install_code}
@@ -199,7 +220,10 @@ def refresh_available_extensions_from_data():
"""
- return code
+ if hidden > 0:
+ code += f"Extension hidden: {hidden}
"
+
+ return code, list(tags)
def create_ui():
@@ -238,21 +262,30 @@ def create_ui():
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
+ with gr.Row():
+ hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"])
+
install_result = gr.HTML()
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
- fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
- inputs=[available_extensions_index],
- outputs=[available_extensions_index, available_extensions_table, install_result],
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+ inputs=[available_extensions_index, hide_tags],
+ outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
)
install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
- inputs=[extension_to_install],
+ inputs=[extension_to_install, hide_tags],
outputs=[available_extensions_table, extensions_table, install_result],
)
+ hide_tags.change(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
+ inputs=[hide_tags],
+ outputs=[available_extensions_table, install_result]
+ )
+
with gr.TabItem("Install from URL"):
install_url = gr.Text(label="URL for extension's git repository")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
--
cgit v1.2.3
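The filtering logic itself is a simple intersection over each extension's tag list: an extension is hidden when it carries any tag the user checked. A sketch over hypothetical index data:
```python
available = [
    {"name": "some-script-pack", "tags": ["script"]},
    {"name": "ad-injector", "tags": ["ads"]},
    {"name": "zh-translation", "tags": ["localization"]},
]
tags_to_hide = {"ads", "localization"}  # default hide set in the UI above

visible = [ext for ext in available
           if not any(tag in tags_to_hide for tag in ext.get("tags", []))]
print([e["name"] for e in visible])  # ['some-script-pack']
```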
From 6e4de5b4422dfc0d45063b2c8c78b19f00321615 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 6 Nov 2022 11:20:23 +0300
Subject: add load_with_extra function for modules to load checkpoints with
extended whitelist
---
modules/safe.py | 40 +++++++++++++++++++++++++++++++++++++---
1 file changed, 37 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/safe.py b/modules/safe.py
index 348a24fc..a9209e38 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -23,11 +23,18 @@ def encode(*args):
class RestrictedUnpickler(pickle.Unpickler):
+ extra_handler = None
+
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
return TypedStorage()
def find_class(self, module, name):
+ if self.extra_handler is not None:
+ res = self.extra_handler(module, name)
+ if res is not None:
+ return res
+
if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
@@ -52,7 +59,7 @@ class RestrictedUnpickler(pickle.Unpickler):
return set
# Forbid everything else.
- raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
+ raise Exception(f"global '{module}/{name}' is forbidden")
allowed_zip_names = ["archive/data.pkl", "archive/version"]
@@ -69,7 +76,7 @@ def check_zip_filenames(filename, names):
raise Exception(f"bad file inside {filename}: {name}")
-def check_pt(filename):
+def check_pt(filename, extra_handler):
try:
# new pytorch format is a zip file
@@ -78,6 +85,7 @@ def check_pt(filename):
with z.open('archive/data.pkl') as file:
unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
unpickler.load()
except zipfile.BadZipfile:
@@ -85,16 +93,42 @@ def check_pt(filename):
# if it's not a zip file, it's an old pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
for i in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
+ return load_with_extra(filename, *args, **kwargs)
+
+
+def load_with_extra(filename, extra_handler=None, *args, **kwargs):
+ """
+ this function is intended to be used by extensions that want to load models with
+ some extra classes in them that the usual unpickler would find suspicious.
+
+ Use the extra_handler argument to specify a function that takes module and field name as text,
+ and returns that field's value:
+
+ ```python
+ def extra(module, name):
+ if module == 'collections' and name == 'OrderedDict':
+ return collections.OrderedDict
+
+ return None
+
+ safe.load_with_extra('model.pt', extra_handler=extra)
+ ```
+
+ The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
+ definitely unsafe.
+ """
+
from modules import shared
try:
if not shared.cmd_opts.disable_safe_unpickle:
- check_pt(filename)
+ check_pt(filename, extra_handler)
except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
--
cgit v1.2.3
From 55ca04095845b41bf66333b3b7343e3ea0babed1 Mon Sep 17 00:00:00 2001
From: Billy Cao
Date: Sun, 6 Nov 2022 16:31:44 +0800
Subject: Resolve conflict
---
modules/processing.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 86d015af..db35983b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -422,14 +422,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
- opts.data[k] = v # we don't call onchange for simplicity which makes changing model impossible
+ setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model impossible
if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not
res = process_images_inner(p)
finally: # restore opts to original state
for k, v in stored_opts.items():
- opts.data[k] = v
+ setattr(opts, k, v)
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
return res
--
cgit v1.2.3
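Going through setattr(opts, k, v) instead of writing opts.data[k] directly lets the options object apply its own handling, and the finally block guarantees the originals come back even if processing fails. The save/override/restore shape, sketched with a stand-in object:
```python
class Opts:                      # stand-in for shared.opts
    sd_hypernetwork = "None"

opts = Opts()
override_settings = {"sd_hypernetwork": "anime_3"}

stored = {k: getattr(opts, k) for k in override_settings}
try:
    for k, v in override_settings.items():
        setattr(opts, k, v)
    # ... run process_images_inner(p) under the overridden settings ...
finally:
    for k, v in stored.items():  # restore even on exceptions
        setattr(opts, k, v)
```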
From 32c0eab89538ba3900bf499291720f80ae4b43e5 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 6 Nov 2022 14:39:41 +0300
Subject: load all settings in one call instead of one by one when the page
loads
---
modules/ui.py | 22 ++++++++++++++++------
1 file changed, 16 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index c946ad59..34c31ef1 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1141,7 +1141,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[html, generation_info, html2],
)
- with gr.Blocks() as modelmerger_interface:
+ with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="A merger of the two checkpoints will be generated in your checkpoint directory.
")
@@ -1161,7 +1161,7 @@ def create_ui(wrap_gradio_gpu_call):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
- with gr.Blocks() as train_interface:
+ with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
gr.HTML(value="See wiki for detailed explanation.
")
@@ -1420,15 +1420,14 @@ def create_ui(wrap_gradio_gpu_call):
if info.refresh is not None:
if is_quicksettings:
- res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
with gr.Row(variant="compact"):
- res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
- res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
-
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
return res
@@ -1639,6 +1638,17 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[component, text_settings],
)
+ component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
+
+ def get_settings_values():
+ return [getattr(opts, key) for key in component_keys]
+
+ demo.load(
+ fn=get_settings_values,
+ inputs=[],
+ outputs=[component_dict[k] for k in component_keys],
+ )
+
def modelmerger(*args):
try:
results = modules.extras.run_modelmerger(*args)
--
cgit v1.2.3
From 67c8e11be74180be19341aebbd6a246c37a79fbb Mon Sep 17 00:00:00 2001
From: snowmeow2
Date: Mon, 7 Nov 2022 02:32:06 +0800
Subject: Adding DeepDanbooru to the interrogation API
---
modules/api/api.py | 16 ++++++++++++++--
modules/api/models.py | 1 +
2 files changed, 15 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 688469ad..596a6616 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -15,6 +15,9 @@ from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from typing import List
+if shared.cmd_opts.deepdanbooru:
+ from modules.deepbooru import get_deepbooru_tags
+
def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
@@ -220,11 +223,20 @@ class Api:
if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found")
- img = self.__base64_to_image(image_b64)
+ img = decode_base64_to_image(image_b64)
+ img = img.convert('RGB')
# Override object param
with self.queue_lock:
- processed = shared.interrogator.interrogate(img)
+ if interrogatereq.model == "clip":
+ processed = shared.interrogator.interrogate(img)
+ elif interrogatereq.model == "deepdanbooru":
+ if shared.cmd_opts.deepdanbooru:
+ processed = get_deepbooru_tags(img)
+ else:
+ raise HTTPException(status_code=404, detail="Model not found. Add --deepdanbooru when launching for using the model.")
+ else:
+ raise HTTPException(status_code=404, detail="Model not found")
return InterrogateResponse(caption=processed)
diff --git a/modules/api/models.py b/modules/api/models.py
index 34dbfa16..f9cd929e 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -170,6 +170,7 @@ class ProgressResponse(BaseModel):
class InterrogateRequest(BaseModel):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+ model: str = Field(default="clip", title="Model", description="The interrogate model used.")
class InterrogateResponse(BaseModel):
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
--
cgit v1.2.3
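A hedged client-side sketch (assumes the webui was launched with --api plus --deepdanbooru, and that the handler above is routed at /sdapi/v1/interrogate):
```python
import base64
import requests

with open("example.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

resp = requests.post("http://127.0.0.1:7860/sdapi/v1/interrogate",
                     json={"image": image_b64, "model": "deepdanbooru"})  # or "clip"
print(resp.json()["caption"])
```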
From cd6c55c1ab14fcab15329cde599cf79e8d555657 Mon Sep 17 00:00:00 2001
From: pepe10-gpu
Date: Sun, 6 Nov 2022 17:05:51 -0800
Subject: 16xx card fix
cudnn
---
modules/devices.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'modules')
diff --git a/modules/devices.py b/modules/devices.py
index 7511e1dc..858bf399 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -39,10 +39,13 @@ def torch_gc():
def enable_tf32():
if torch.cuda.is_available():
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.enabled = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
+
errors.run(enable_tf32, "Enabling TF32")
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
--
cgit v1.2.3
From a258fd60dbe2d68325339405a2aa72816d06d2fd Mon Sep 17 00:00:00 2001
From: Keavon Chambers
Date: Mon, 7 Nov 2022 00:13:58 -0800
Subject: Add CORS-allow policy launch argument using regex
---
modules/shared.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index e8bacd3c..55de286d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -81,12 +81,13 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
-parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
-parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
+parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
+parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
-parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None)
+parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
+parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
--
cgit v1.2.3
From 9ed4a126bd6421f91bf4a9bdd348b6aef0a378c6 Mon Sep 17 00:00:00 2001
From: kavorite
Date: Mon, 7 Nov 2022 19:58:49 -0500
Subject: add gradio-inpaint-tool; color-sketch
---
modules/img2img.py | 19 +++++++++++++------
modules/shared.py | 1 +
modules/ui.py | 11 ++++++++++-
3 files changed, 24 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/img2img.py b/modules/img2img.py
index be9f3653..00c6f827 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -59,18 +59,25 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_inpaint = mode == 1
is_batch = mode == 2
if is_inpaint:
# Drawn mask
if mask_mode == 0:
- image = init_img_with_mask['image']
- mask = init_img_with_mask['mask']
- alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
- image = image.convert('RGB')
+ image = init_img_with_mask
+ is_mask_sketch = isinstance(image, dict)
+ if is_mask_sketch:
+ # Sketch: mask iff. not transparent
+ image, mask = image["image"], image["mask"]
+ mask = np.array(mask)[..., -1] > 0
+ else:
+ # Color-sketch: mask iff. painted over
+ orig = init_img_with_mask_orig or image
+ mask = np.any(np.array(image) != np.array(orig), axis=-1)
+ mask = Image.fromarray(mask.astype(np.uint8) * 255, "L")
+ image = image.convert("RGB")
# Uploaded mask
else:
image = init_img_inpaint
diff --git a/modules/shared.py b/modules/shared.py
index d8e99f85..325e37d9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -71,6 +71,7 @@ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
+parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
diff --git a/modules/ui.py b/modules/ui.py
index 2609857e..db323e9c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -840,8 +840,17 @@ def create_ui(wrap_gradio_gpu_call):
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
with gr.TabItem('Inpaint', id='inpaint'):
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
+ init_img_with_mask_orig = gr.State(None)
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
+ def update_orig(image, state):
+ if image is not None:
+ same_size = state is not None and state.size == image.size
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+ edited = same_size and has_exact_match
+ return image if not edited or state is None else state
+
+ init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
--
cgit v1.2.3
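The color-sketch mask derivation above is: any pixel that differs from the untouched original counts as painted, hence masked. In isolation:
```python
import numpy as np
from PIL import Image, ImageDraw

orig = Image.new("RGB", (64, 64), "white")
painted = orig.copy()
ImageDraw.Draw(painted).rectangle((8, 8, 24, 24), fill="red")  # user paints a region

changed = np.any(np.array(painted) != np.array(orig), axis=-1)  # True where painted
mask = Image.fromarray(changed.astype(np.uint8) * 255, "L")     # white = inpaint here
```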
From 29eff4a194d22f0f0e7a7a976d746a71a4193cf5 Mon Sep 17 00:00:00 2001
From: pepe10-gpu
Date: Mon, 7 Nov 2022 18:06:48 -0800
Subject: terrible hack
---
modules/devices.py | 11 +++++++++--
1 file changed, 9 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/devices.py b/modules/devices.py
index 858bf399..4c63f465 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -39,8 +39,15 @@ def torch_gc():
def enable_tf32():
if torch.cuda.is_available():
- torch.backends.cudnn.benchmark = True
- torch.backends.cudnn.enabled = True
+ #TODO: make this better; find a way to check if it is a turing card
+ turing = ["1630","1650","1660","Quadro RTX 3000","Quadro RTX 4000","Quadro RTX 4000","Quadro RTX 5000","Quadro RTX 5000","Quadro RTX 6000","Quadro RTX 6000","Quadro RTX 8000","Quadro RTX T400","Quadro RTX T400","Quadro RTX T600","Quadro RTX T1000","Quadro RTX T1000","2060","2070","2080","Titan RTX","Tesla T4","MX450","MX550"]
+ for devid in range(0,torch.cuda.device_count()):
+ for i in turing:
+ if i in torch.cuda.get_device_name(devid):
+ shd = True
+ if shd:
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.enabled = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
--
cgit v1.2.3
From c5334fc56b3d44976425da2e6d0a303ae96836a1 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 8 Nov 2022 08:35:01 +0300
Subject: fix javascript duplication bug after pressing the restart UI button
---
modules/ui.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 34c31ef1..67cf1d6a 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1752,7 +1752,7 @@ def create_ui(wrap_gradio_gpu_call):
return demo
-def load_javascript(raw_response):
+def reload_javascript():
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
javascript = f''
@@ -1768,7 +1768,7 @@ def load_javascript(raw_response):
javascript += f"\n"
def template_response(*args, **kwargs):
- res = raw_response(*args, **kwargs)
+ res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace(
b'', f'{javascript}'.encode("utf8"))
res.init_headers()
@@ -1777,4 +1777,5 @@ def load_javascript(raw_response):
gradio.routes.templates.TemplateResponse = template_response
-reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
+if not hasattr(shared, 'GradioTemplateResponseOriginal'):
+ shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
--
cgit v1.2.3
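The duplication fix is the standard idempotent monkey-patch guard: capture the pristine original exactly once, so re-running the patcher wraps the original rather than an earlier wrapper. The pattern in isolation, with stand-ins for gradio's template response:
```python
import types

store = types.SimpleNamespace()      # stand-in for the shared module

def render():                        # stand-in for TemplateResponse
    return "page"

def patch():
    global render
    if not hasattr(store, "original_render"):  # only the first call captures it
        store.original_render = render
    def wrapped():
        return store.original_render() + " + injected js"
    render = wrapped

patch()
patch()          # pressing "restart UI" again is now harmless
print(render())  # 'page + injected js', not '... + injected js + injected js'
```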
From 8011be33c36eb7aa9e9498fc714614034e07f67a Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 8 Nov 2022 08:37:05 +0300
Subject: move functions out of main body for image preprocessing for easier
hijacking
---
modules/textual_inversion/preprocess.py | 162 ++++++++++++++++++--------------
1 file changed, 93 insertions(+), 69 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index e13b1894..488aa5b5 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -35,6 +35,84 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
deepbooru.release_process()
+def listfiles(dirname):
+ return os.listdir(dirname)
+
+
+class PreprocessParams:
+ src = None
+ dstdir = None
+ subindex = 0
+ flip = False
+ process_caption = False
+ process_caption_deepbooru = False
+ preprocess_txt_action = None
+
+
+def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
+ caption = ""
+
+ if params.process_caption:
+ caption += shared.interrogator.generate_caption(image)
+
+ if params.process_caption_deepbooru:
+ if len(caption) > 0:
+ caption += ", "
+ caption += deepbooru.get_tags_from_process(image)
+
+ filename_part = params.src
+ filename_part = os.path.splitext(filename_part)[0]
+ filename_part = os.path.basename(filename_part)
+
+ basename = f"{index:05}-{params.subindex}-{filename_part}"
+ image.save(os.path.join(params.dstdir, f"{basename}.png"))
+
+ if params.preprocess_txt_action == 'prepend' and existing_caption:
+ caption = existing_caption + ' ' + caption
+ elif params.preprocess_txt_action == 'append' and existing_caption:
+ caption = caption + ' ' + existing_caption
+ elif params.preprocess_txt_action == 'copy' and existing_caption:
+ caption = existing_caption
+
+ caption = caption.strip()
+
+ if len(caption) > 0:
+ with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
+ file.write(caption)
+
+ params.subindex += 1
+
+
+def save_pic(image, index, params, existing_caption=None):
+ save_pic_with_caption(image, index, params, existing_caption=existing_caption)
+
+ if params.flip:
+ save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
+
+
+def split_pic(image, inverse_xy, width, height, overlap_ratio):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
+ else:
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
+
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
width = process_width
@@ -48,82 +126,28 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
os.makedirs(dst, exist_ok=True)
- files = os.listdir(src)
+ files = listfiles(src)
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
- def save_pic_with_caption(image, index, existing_caption=None):
- caption = ""
-
- if process_caption:
- caption += shared.interrogator.generate_caption(image)
-
- if process_caption_deepbooru:
- if len(caption) > 0:
- caption += ", "
- caption += deepbooru.get_tags_from_process(image)
-
- filename_part = filename
- filename_part = os.path.splitext(filename_part)[0]
- filename_part = os.path.basename(filename_part)
-
- basename = f"{index:05}-{subindex[0]}-{filename_part}"
- image.save(os.path.join(dst, f"{basename}.png"))
-
- if preprocess_txt_action == 'prepend' and existing_caption:
- caption = existing_caption + ' ' + caption
- elif preprocess_txt_action == 'append' and existing_caption:
- caption = caption + ' ' + existing_caption
- elif preprocess_txt_action == 'copy' and existing_caption:
- caption = existing_caption
-
- caption = caption.strip()
-
- if len(caption) > 0:
- with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
- file.write(caption)
-
- subindex[0] += 1
-
- def save_pic(image, index, existing_caption=None):
- save_pic_with_caption(image, index, existing_caption=existing_caption)
-
- if process_flip:
- save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
-
- def split_pic(image, inverse_xy):
- if inverse_xy:
- from_w, from_h = image.height, image.width
- to_w, to_h = height, width
- else:
- from_w, from_h = image.width, image.height
- to_w, to_h = width, height
- h = from_h * to_w // from_w
- if inverse_xy:
- image = image.resize((h, to_w))
- else:
- image = image.resize((to_w, h))
-
- split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
- y_step = (h - to_h) / (split_count - 1)
- for i in range(split_count):
- y = int(y_step * i)
- if inverse_xy:
- splitted = image.crop((y, 0, y + to_h, to_w))
- else:
- splitted = image.crop((0, y, to_w, y + to_h))
- yield splitted
-
+ params = PreprocessParams()
+ params.dstdir = dst
+ params.flip = process_flip
+ params.process_caption = process_caption
+ params.process_caption_deepbooru = process_caption_deepbooru
+ params.preprocess_txt_action = preprocess_txt_action
for index, imagefile in enumerate(tqdm.tqdm(files)):
- subindex = [0]
+ params.subindex = 0
filename = os.path.join(src, imagefile)
try:
img = Image.open(filename).convert("RGB")
except Exception:
continue
+ params.src = filename
+
existing_caption = None
existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
if os.path.exists(existing_caption_filename):
@@ -143,8 +167,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
process_default_resize = True
if process_split and ratio < 1.0 and ratio <= split_threshold:
- for splitted in split_pic(img, inverse_xy):
- save_pic(splitted, index, existing_caption=existing_caption)
+ for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
+ save_pic(splitted, index, params, existing_caption=existing_caption)
process_default_resize = False
if process_focal_crop and img.height != img.width:
@@ -165,11 +189,11 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
dnn_model_path = dnn_model_path,
)
for focal in autocrop.crop_image(img, autocrop_settings):
- save_pic(focal, index, existing_caption=existing_caption)
+ save_pic(focal, index, params, existing_caption=existing_caption)
process_default_resize = False
if process_default_resize:
img = images.resize_image(1, img, width, height)
- save_pic(img, index, existing_caption=existing_caption)
+ save_pic(img, index, params, existing_caption=existing_caption)
- shared.state.nextjob()
\ No newline at end of file
+ shared.state.nextjob()
--
cgit v1.2.3
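The point of hoisting these helpers to module level is that an extension can now replace one of them instead of copying the whole preprocessing loop. A hedged sketch of such a hijack (the .png filter itself is purely illustrative):
```python
from modules.textual_inversion import preprocess

_orig_listfiles = preprocess.listfiles

def listfiles_png_only(dirname):
    return [f for f in _orig_listfiles(dirname) if f.lower().endswith(".png")]

preprocess.listfiles = listfiles_png_only  # preprocess_work() now only sees .png inputs
```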
From 1610b3258458025025e9c4faae57d290e4519745 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 8 Nov 2022 08:38:10 +0300
Subject: add callback for creating a tab in train UI
---
modules/script_callbacks.py | 27 +++++++++++++++++++++++++--
modules/ui.py | 4 ++++
2 files changed, 29 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 74dfb880..f19e164c 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -7,6 +7,7 @@ from typing import Optional
from fastapi import FastAPI
from gradio import Blocks
+
def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@@ -45,15 +46,21 @@ class CFGDenoiserParams:
"""Total number of sampling steps planned"""
+class UiTrainTabParams:
+ def __init__(self, txt2img_preview_params):
+ self.txt2img_preview_params = txt2img_preview_params
+
+
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
callbacks_app_started=[],
callbacks_model_loaded=[],
callbacks_ui_tabs=[],
+ callbacks_ui_train_tabs=[],
callbacks_ui_settings=[],
callbacks_before_image_saved=[],
callbacks_image_saved=[],
- callbacks_cfg_denoiser=[]
+ callbacks_cfg_denoiser=[],
)
@@ -61,6 +68,7 @@ def clear_callbacks():
for callback_list in callback_map.values():
callback_list.clear()
+
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']:
try:
@@ -79,7 +87,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
-
+
for c in callback_map['callbacks_ui_tabs']:
try:
res += c.callback() or []
@@ -89,6 +97,14 @@ def ui_tabs_callback():
return res
+def ui_train_tabs_callback(params: UiTrainTabParams):
+ for c in callback_map['callbacks_ui_train_tabs']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'callbacks_ui_train_tabs')
+
+
def ui_settings_callback():
for c in callback_map['callbacks_ui_settings']:
try:
@@ -169,6 +185,13 @@ def on_ui_tabs(callback):
add_callback(callback_map['callbacks_ui_tabs'], callback)
+def on_ui_train_tabs(callback):
+ """register a function to be called when the UI is creating new tabs for the train tab.
+ Create your new tabs with gr.Tab.
+ """
+ add_callback(callback_map['callbacks_ui_train_tabs'], callback)
+
+
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
diff --git a/modules/ui.py b/modules/ui.py
index 67cf1d6a..7ea1177f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1270,6 +1270,10 @@ def create_ui(wrap_gradio_gpu_call):
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
train_embedding = gr.Button(value="Train Embedding", variant='primary')
+ params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
+
+ script_callbacks.ui_train_tabs_callback(params)
+
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
--
cgit v1.2.3
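A hedged sketch of an extension consuming the new callback (the tab contents are illustrative; the names come from the diff above):
```python
import gradio as gr
from modules import script_callbacks

def add_my_train_tab(params):
    # params is a UiTrainTabParams; params.txt2img_preview_params holds the preview inputs
    with gr.Tab(label="My trainer"):
        gr.Button(value="Start custom training", variant="primary")

script_callbacks.on_ui_train_tabs(add_my_train_tab)
```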
From c34542a48376e4972de955aab00ffc8359f7d792 Mon Sep 17 00:00:00 2001
From: kavorite
Date: Tue, 8 Nov 2022 03:25:59 -0500
Subject: add new color-sketch state to img2img invocation
---
modules/ui.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index db323e9c..29954f2a 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -941,6 +941,7 @@ def create_ui(wrap_gradio_gpu_call):
img2img_prompt_style2,
init_img,
init_img_with_mask,
+ init_img_with_mask_orig,
init_img_inpaint,
init_mask_inpaint,
mask_mode,
--
cgit v1.2.3
From cfcadeae9a61e1aff32960864f90299412c86d5c Mon Sep 17 00:00:00 2001
From: d8ahazard
Date: Tue, 8 Nov 2022 10:03:56 -0600
Subject: Add option to preload extensions
By creating a file called "preload.py" in an extension folder and declaring a preload(parser) method, we can add extra command-line args for an extension.
---
modules/extensions.py | 23 ++++++++++++++++++++++-
modules/shared.py | 5 ++++-
2 files changed, 26 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/extensions.py b/modules/extensions.py
index 8e0977fd..544f3580 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,12 +1,12 @@
import os
import sys
import traceback
+from importlib.machinery import SourceFileLoader
import git
from modules import paths, shared
-
extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions")
@@ -84,3 +84,24 @@ def list_extensions():
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
extensions.append(extension)
+
+
+def preload_extensions(parser):
+ if not os.path.isdir(extensions_dir):
+ return
+
+ for dirname in sorted(os.listdir(extensions_dir)):
+ path = os.path.join(extensions_dir, dirname)
+ if not os.path.isdir(path):
+ continue
+ for file in os.listdir(path):
+ if "preload.py" in file:
+ full_file = os.path.join(path, file)
+ print(f"Got preload file: {full_file}")
+
+ try:
+ ext = SourceFileLoader("preload", full_file).load_module()
+ parser = ext.preload(parser)
+ except Exception as e:
+ print(f"Exception preloading script: {e}")
+ return parser
\ No newline at end of file
diff --git a/modules/shared.py b/modules/shared.py
index e8bacd3c..222ad4fb 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -15,7 +15,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers, sd_models, localization, sd_vae
+from modules import sd_samplers, sd_models, localization, sd_vae, extensions
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
@@ -91,7 +91,10 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
+extensions.preload_extensions(parser)
+
cmd_opts = parser.parse_args()
+
restricted_opts = {
"samples_filename_pattern",
"directories_filename_pattern",
--
cgit v1.2.3
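Since preload_extensions does `parser = ext.preload(parser)`, a preload hook should return the parser it received. A minimal sketch of `extensions/<name>/preload.py` (the flag is illustrative):
```python
def preload(parser):
    parser.add_argument("--my-extension-model-dir", type=str, default=None,
                        help="example flag registered by an extension before cmd_opts parsing")
    return parser  # preload_extensions() reassigns parser from this return value
```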
From 62e9fec3df8518da3a2c35fa090bb54946c856b2 Mon Sep 17 00:00:00 2001
From: pepe10-gpu
Date: Tue, 8 Nov 2022 15:19:09 -0800
Subject: actual better fix
thanks C43H66N12O12S2
---
modules/devices.py | 7 ++-----
1 file changed, 2 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/devices.py b/modules/devices.py
index 4c63f465..058a5e00 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -39,12 +39,9 @@ def torch_gc():
def enable_tf32():
if torch.cuda.is_available():
- #TODO: make this better; find a way to check if it is a turing card
- turing = ["1630","1650","1660","Quadro RTX 3000","Quadro RTX 4000","Quadro RTX 4000","Quadro RTX 5000","Quadro RTX 5000","Quadro RTX 6000","Quadro RTX 6000","Quadro RTX 8000","Quadro RTX T400","Quadro RTX T400","Quadro RTX T600","Quadro RTX T1000","Quadro RTX T1000","2060","2070","2080","Titan RTX","Tesla T4","MX450","MX550"]
for devid in range(0,torch.cuda.device_count()):
- for i in turing:
- if i in torch.cuda.get_device_name(devid):
- shd = True
+ if torch.cuda.get_device_capability(devid) == (7, 5):
+ shd = True
if shd:
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
--
cgit v1.2.3
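Compute capability (7, 5) is the Turing generation, so the check no longer depends on matching marketing names. The check in isolation:
```python
import torch

def has_turing_gpu():
    if not torch.cuda.is_available():
        return False
    return any(torch.cuda.get_device_capability(i) == (7, 5)
               for i in range(torch.cuda.device_count()))

if has_turing_gpu():  # 16xx/20xx-era cards need the cuDNN workaround
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
```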
From 59bb1d36ea69db449cfe23be4988ab4f6711bf4b Mon Sep 17 00:00:00 2001
From: kavorite
Date: Tue, 8 Nov 2022 22:06:29 -0500
Subject: blur mask with color-sketch + add paint transparency slider
---
modules/img2img.py | 21 +++++++++++++--------
modules/ui.py | 3 +++
2 files changed, 16 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/img2img.py b/modules/img2img.py
index 00c6f827..644297da 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -4,7 +4,7 @@ import sys
import traceback
import numpy as np
-from PIL import Image, ImageOps, ImageChops
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance
from modules import devices
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
@@ -40,7 +40,7 @@ def process_batch(p, input_dir, output_dir, args):
img = Image.open(image)
# Use the EXIF orientation of photos taken by smartphones.
- img = ImageOps.exif_transpose(img)
+ img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
proc = modules.scripts.scripts_img2img.run(p, *args)
@@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_inpaint = mode == 1
is_batch = mode == 2
@@ -68,15 +68,20 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if mask_mode == 0:
image = init_img_with_mask
is_mask_sketch = isinstance(image, dict)
- if is_mask_sketch:
+ is_mask_paint = not is_mask_sketch
+ if is_mask_sketch:
# Sketch: mask iff. not transparent
image, mask = image["image"], image["mask"]
- mask = np.array(mask)[..., -1] > 0
+ pred = np.array(mask)[..., -1] > 0
else:
# Color-sketch: mask iff. painted over
orig = init_img_with_mask_orig or image
- mask = np.any(np.array(image) != np.array(orig), axis=-1)
- mask = Image.fromarray(mask.astype(np.uint8) * 255, "L")
+ pred = np.any(np.array(image) != np.array(orig), axis=-1)
+ mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
+ if is_mask_paint:
+ mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
+ blur = ImageFilter.GaussianBlur(mask_blur)
+ image = Image.composite(image.filter(blur), orig, mask.filter(blur))
image = image.convert("RGB")
# Uploaded mask
else:
@@ -89,7 +94,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
# Use the EXIF orientation of photos taken by smartphones.
if image is not None:
- image = ImageOps.exif_transpose(image)
+ image = ImageOps.exif_transpose(image)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
diff --git a/modules/ui.py b/modules/ui.py
index 29954f2a..16982abf 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -854,6 +854,8 @@ def create_ui(wrap_gradio_gpu_call):
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
+ show_mask_alpha = cmd_opts.gradio_inpaint_tool == "color-sketch"
+ mask_alpha = gr.Slider(label="Mask transparency", interactive=show_mask_alpha, visible=show_mask_alpha)
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
with gr.Row():
@@ -948,6 +950,7 @@ def create_ui(wrap_gradio_gpu_call):
steps,
sampler_index,
mask_blur,
+ mask_alpha,
inpainting_fill,
restore_faces,
tiling,
--
cgit v1.2.3
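The color-sketch path in this patch derives the inpaint mask by diffing the painted canvas against the original image, dims it by the transparency slider, then blurs both mask and sketch before compositing. A self-contained sketch of the same idea; the helper name `blend_color_sketch` is hypothetical:

```python
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter

def blend_color_sketch(painted: Image.Image, orig: Image.Image,
                       mask_blur: int = 4, mask_alpha: float = 0.0):
    """Derive an inpaint mask from a color sketch and soft-blend it back.

    Assumes `painted` and `orig` have the same size and mode (e.g. RGB).
    `mask_alpha` mirrors the 0..100 "Mask transparency" slider.
    """
    # Mask iff painted over: any channel differs from the original pixel.
    pred = np.any(np.array(painted) != np.array(orig), axis=-1)
    mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
    # Dim the mask so the paint is partially transparent in the blend.
    mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
    # Blur mask and sketch alike so the composite has soft edges.
    blur = ImageFilter.GaussianBlur(mask_blur)
    blended = Image.composite(painted.filter(blur), orig, mask.filter(blur))
    return blended.convert("RGB"), mask
```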
From 3b51d239ac9201228c6032fc109111e347e8e6b0 Mon Sep 17 00:00:00 2001
From: cluder <1590330+cluder@users.noreply.github.com>
Date: Wed, 9 Nov 2022 04:54:21 +0100
Subject: - do not use ckpt cache if disabled - cache model after it has been
loaded from file
---
modules/sd_models.py | 27 +++++++++++++++++----------
1 file changed, 17 insertions(+), 10 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 34c57bfa..720c2a96 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -163,13 +163,21 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
- if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
+ cache_enabled = shared.opts.sd_checkpoint_cache > 0
+
+ if cache_enabled:
sd_vae.restore_base_vae(model)
- checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
- if checkpoint_info not in checkpoints_loaded:
+ if cache_enabled and checkpoint_info in checkpoints_loaded:
+ # use checkpoint cache
+ vae_name = sd_vae.get_filename(vae_file) if vae_file else None
+ vae_message = f" with {vae_name} VAE" if vae_name else ""
+ print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
+ model.load_state_dict(checkpoints_loaded[checkpoint_info])
+ else:
+ # load from file
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
@@ -180,6 +188,10 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
del pl_sd
model.load_state_dict(sd, strict=False)
del sd
+
+ if cache_enabled:
+ # cache newly loaded model
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
@@ -199,13 +211,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.first_stage_model.to(devices.dtype_vae)
- else:
- vae_name = sd_vae.get_filename(vae_file) if vae_file else None
- vae_message = f" with {vae_name} VAE" if vae_name else ""
- print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
- model.load_state_dict(checkpoints_loaded[checkpoint_info])
-
- if shared.opts.sd_checkpoint_cache > 0:
+ # clean up cache if limit is reached
+ if cache_enabled:
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
--
cgit v1.2.3
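In this patch, `checkpoints_loaded` acts as an LRU cache keyed by checkpoint: weights are cached only after a successful load from file, and `popitem(last=False)` evicts the oldest entry. A minimal model of the pattern, with hypothetical helper names:

```python
import collections

# An OrderedDict doubles as an insertion-ordered LRU cache.
checkpoints_loaded = collections.OrderedDict()

def cache_weights(checkpoint_info, model, limit):
    # Cache a (shallow) copy of the weights after a successful load from file.
    checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
    while len(checkpoints_loaded) > limit:
        checkpoints_loaded.popitem(last=False)  # evict the oldest entry first

def load_from_cache(checkpoint_info, model):
    # Reuse cached weights instead of re-reading the checkpoint from disk.
    if checkpoint_info in checkpoints_loaded:
        model.load_state_dict(checkpoints_loaded[checkpoint_info])
        return True
    return False
```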
From eebf49592ad2c0933e58b06a098b92e48d47e4fe Mon Sep 17 00:00:00 2001
From: cluder <1590330+cluder@users.noreply.github.com>
Date: Wed, 9 Nov 2022 07:17:09 +0100
Subject: restore #4035 behavior
- if checkpoint cache is set to 1, keep 2 models in cache (current + 1 more)
---
modules/sd_models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 720c2a96..80addf03 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -213,7 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
# clean up cache if limit is reached
if cache_enabled:
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash
--
cgit v1.2.3
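The one-line change above counts the currently loaded model against the eviction budget. Continuing the hypothetical sketch from the previous patch:

```python
def trim_cache(limit):
    # Keep the active model plus `limit` spares: with limit == 1 the dict
    # holds two state dicts, so switching back and forth needs no reload.
    while len(checkpoints_loaded) > limit + 1:
        checkpoints_loaded.popitem(last=False)
```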
From 81f2575df91a50e4aa9ca816e02e3f77342eedc8 Mon Sep 17 00:00:00 2001
From: Liam
Date: Wed, 9 Nov 2022 15:24:31 -0500
Subject: update the displayed generation info when the user clicks images in
the gallery (feature request #4415)
---
modules/ui.py | 20 ++++++++++++++++++++
1 file changed, 20 insertions(+)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 7ea1177f..756499d1 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -566,6 +566,17 @@ def apply_setting(key, value):
return value
+def update_generation_info(args):
+ generation_info, html_info, img_index = args
+ try:
+ generation_info = json.loads(generation_info)
+ return plaintext_to_html(generation_info["infotexts"][img_index])
+ except Exception:
+ pass
+ # if the json parse or anything else fails, just return the old html_info
+ return html_info
+
+
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
@@ -638,6 +649,15 @@ Requested path was: {f}
with gr.Group():
html_info = gr.HTML()
generation_info = gr.Textbox(visible=False)
+ if tabname == 'txt2img' or tabname == 'img2img':
+ generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+ generation_info_button.click(
+ fn=update_generation_info,
+ _js="(x, y) => [x, y, selected_gallery_index()]",
+ inputs=[generation_info, html_info],
+ outputs=[html_info],
+ preprocess=False
+ )
save.click(
fn=wrap_gradio_call(save_files),
--
cgit v1.2.3
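The wiring in this patch uses a hidden `gr.Button` whose `_js` callback rewrites the inputs client-side, appending `selected_gallery_index()` (a helper from the webui's own JavaScript) before the Python handler runs; `preprocess=False` passes the raw values through. A self-contained sketch of the handler, with a simplified stand-in for the webui's `plaintext_to_html`:

```python
import json

def plaintext_to_html(text: str) -> str:
    # Simplified stand-in for the webui helper of the same name.
    return "<p>" + "<br>\n".join(text.splitlines()) + "</p>"

def update_generation_info(generation_info, html_info, img_index):
    # `generation_info` is the JSON blob kept in the hidden textbox; its
    # "infotexts" list holds one parameter string per gallery image. The
    # index is appended client-side by the _js callback.
    try:
        infotexts = json.loads(generation_info)["infotexts"]
        return plaintext_to_html(infotexts[img_index])
    except Exception:
        return html_info  # on any failure, leave the panel unchanged
```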
From 893191cab24cc3511135495d6d2c8d81f5ec63a3 Mon Sep 17 00:00:00 2001
From: Tong Zeng