From ad4de819c43997f2666b5bad95301f5c37f9018e Mon Sep 17 00:00:00 2001
From: victorca25
Date: Sun, 9 Oct 2022 13:02:12 +0200
Subject: update ESRGAN architecture and model to support all ESRGAN models in
the DB, BSRGAN and real-ESRGAN models
---
modules/bsrgan_model.py | 76 -------
modules/bsrgan_model_arch.py | 102 ----------
modules/esrgam_model_arch.py | 80 --------
modules/esrgan_model.py | 190 ++++++++++++------
modules/esrgan_model_arch.py | 463 +++++++++++++++++++++++++++++++++++++++++++
5 files changed, 591 insertions(+), 320 deletions(-)
delete mode 100644 modules/bsrgan_model.py
delete mode 100644 modules/bsrgan_model_arch.py
delete mode 100644 modules/esrgam_model_arch.py
create mode 100644 modules/esrgan_model_arch.py
(limited to 'modules')
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py
deleted file mode 100644
index 737e1a76..00000000
--- a/modules/bsrgan_model.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import os.path
-import sys
-import traceback
-
-import PIL.Image
-import numpy as np
-import torch
-from basicsr.utils.download_util import load_file_from_url
-
-import modules.upscaler
-from modules import devices, modelloader
-from modules.bsrgan_model_arch import RRDBNet
-
-
-class UpscalerBSRGAN(modules.upscaler.Upscaler):
- def __init__(self, dirname):
- self.name = "BSRGAN"
- self.model_name = "BSRGAN 4x"
- self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
- self.user_path = dirname
- super().__init__()
- model_paths = self.find_models(ext_filter=[".pt", ".pth"])
- scalers = []
- if len(model_paths) == 0:
- scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
- scalers.append(scaler_data)
- for file in model_paths:
- if "http" in file:
- name = self.model_name
- else:
- name = modelloader.friendly_name(file)
- try:
- scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
- scalers.append(scaler_data)
- except Exception:
- print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- self.scalers = scalers
-
- def do_upscale(self, img: PIL.Image, selected_file):
- torch.cuda.empty_cache()
- model = self.load_model(selected_file)
- if model is None:
- return img
- model.to(devices.device_bsrgan)
- torch.cuda.empty_cache()
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_bsrgan)
- with torch.no_grad():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- torch.cuda.empty_cache()
- return PIL.Image.fromarray(output, 'RGB')
-
- def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
- progress=True)
- else:
- filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
- return None
- model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
- model.load_state_dict(torch.load(filename), strict=True)
- model.eval()
- for k, v in model.named_parameters():
- v.requires_grad = False
- return model
-
diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py
deleted file mode 100644
index cb4d1c13..00000000
--- a/modules/bsrgan_model_arch.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import functools
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.nn.init as init
-
-
-def initialize_weights(net_l, scale=1):
- if not isinstance(net_l, list):
- net_l = [net_l]
- for net in net_l:
- for m in net.modules():
- if isinstance(m, nn.Conv2d):
- init.kaiming_normal_(m.weight, a=0, mode='fan_in')
- m.weight.data *= scale # for residual block
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.Linear):
- init.kaiming_normal_(m.weight, a=0, mode='fan_in')
- m.weight.data *= scale
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.BatchNorm2d):
- init.constant_(m.weight, 1)
- init.constant_(m.bias.data, 0.0)
-
-
-def make_layer(block, n_layers):
- layers = []
- for _ in range(n_layers):
- layers.append(block())
- return nn.Sequential(*layers)
-
-
-class ResidualDenseBlock_5C(nn.Module):
- def __init__(self, nf=64, gc=32, bias=True):
- super(ResidualDenseBlock_5C, self).__init__()
- # gc: growth channel, i.e. intermediate channels
- self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
- self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
- self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
- self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
- self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- # initialization
- initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
-
- def forward(self, x):
- x1 = self.lrelu(self.conv1(x))
- x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
- x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
- x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- return x5 * 0.2 + x
-
-
-class RRDB(nn.Module):
- '''Residual in Residual Dense Block'''
-
- def __init__(self, nf, gc=32):
- super(RRDB, self).__init__()
- self.RDB1 = ResidualDenseBlock_5C(nf, gc)
- self.RDB2 = ResidualDenseBlock_5C(nf, gc)
- self.RDB3 = ResidualDenseBlock_5C(nf, gc)
-
- def forward(self, x):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
- return out * 0.2 + x
-
-
-class RRDBNet(nn.Module):
- def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
- super(RRDBNet, self).__init__()
- RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
- self.sf = sf
-
- self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
- self.RRDB_trunk = make_layer(RRDB_block_f, nb)
- self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- #### upsampling
- self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- if self.sf==4:
- self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
-
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- def forward(self, x):
- fea = self.conv_first(x)
- trunk = self.trunk_conv(self.RRDB_trunk(fea))
- fea = fea + trunk
-
- fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
- if self.sf==4:
- fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
- out = self.conv_last(self.lrelu(self.HRconv(fea)))
-
- return out
\ No newline at end of file
diff --git a/modules/esrgam_model_arch.py b/modules/esrgam_model_arch.py
deleted file mode 100644
index e413d36e..00000000
--- a/modules/esrgam_model_arch.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# this file is taken from https://github.com/xinntao/ESRGAN
-
-import functools
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-def make_layer(block, n_layers):
- layers = []
- for _ in range(n_layers):
- layers.append(block())
- return nn.Sequential(*layers)
-
-
-class ResidualDenseBlock_5C(nn.Module):
- def __init__(self, nf=64, gc=32, bias=True):
- super(ResidualDenseBlock_5C, self).__init__()
- # gc: growth channel, i.e. intermediate channels
- self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
- self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
- self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
- self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
- self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- # initialization
- # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
-
- def forward(self, x):
- x1 = self.lrelu(self.conv1(x))
- x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
- x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
- x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- return x5 * 0.2 + x
-
-
-class RRDB(nn.Module):
- '''Residual in Residual Dense Block'''
-
- def __init__(self, nf, gc=32):
- super(RRDB, self).__init__()
- self.RDB1 = ResidualDenseBlock_5C(nf, gc)
- self.RDB2 = ResidualDenseBlock_5C(nf, gc)
- self.RDB3 = ResidualDenseBlock_5C(nf, gc)
-
- def forward(self, x):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
- return out * 0.2 + x
-
-
-class RRDBNet(nn.Module):
- def __init__(self, in_nc, out_nc, nf, nb, gc=32):
- super(RRDBNet, self).__init__()
- RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
-
- self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
- self.RRDB_trunk = make_layer(RRDB_block_f, nb)
- self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- #### upsampling
- self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
-
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- def forward(self, x):
- fea = self.conv_first(x)
- trunk = self.trunk_conv(self.RRDB_trunk(fea))
- fea = fea + trunk
-
- fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
- fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
- out = self.conv_last(self.lrelu(self.HRconv(fea)))
-
- return out
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 3970e6e4..a49e2258 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -5,68 +5,115 @@ import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
-import modules.esrgam_model_arch as arch
+import modules.esrgan_model_arch as arch
from modules import shared, modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-def fix_model_layers(crt_model, pretrained_net):
- # this code is adapted from https://github.com/xinntao/ESRGAN
- if 'conv_first.weight' in pretrained_net:
- return pretrained_net
- if 'model.0.weight' not in pretrained_net:
- is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
- if is_realesrgan:
- raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
- else:
- raise Exception("The file is not a ESRGAN model.")
+def mod2normal(state_dict):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ if 'conv_first.weight' in state_dict:
+ crt_net = {}
+ items = []
+ for k, v in state_dict.items():
+ items.append(k)
+
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+ for k in items.copy():
+ if 'RDB' in k:
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[ori_k] = state_dict[k]
+ items.remove(k)
+
+ crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
+ crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
+ crt_net['model.3.weight'] = state_dict['upconv1.weight']
+ crt_net['model.3.bias'] = state_dict['upconv1.bias']
+ crt_net['model.6.weight'] = state_dict['upconv2.weight']
+ crt_net['model.6.bias'] = state_dict['upconv2.bias']
+ crt_net['model.8.weight'] = state_dict['HRconv.weight']
+ crt_net['model.8.bias'] = state_dict['HRconv.bias']
+ crt_net['model.10.weight'] = state_dict['conv_last.weight']
+ crt_net['model.10.bias'] = state_dict['conv_last.bias']
+ state_dict = crt_net
+ return state_dict
+
+
+def resrgan2normal(state_dict, nb=23):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+ crt_net = {}
+ items = []
+ for k, v in state_dict.items():
+ items.append(k)
+
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+ for k in items.copy():
+ if "rdb" in k:
+ ori_k = k.replace('body.', 'model.1.sub.')
+ ori_k = ori_k.replace('.rdb', '.RDB')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[ori_k] = state_dict[k]
+ items.remove(k)
+
+ crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
+ crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
+ crt_net['model.3.weight'] = state_dict['conv_up1.weight']
+ crt_net['model.3.bias'] = state_dict['conv_up1.bias']
+ crt_net['model.6.weight'] = state_dict['conv_up2.weight']
+ crt_net['model.6.bias'] = state_dict['conv_up2.bias']
+ crt_net['model.8.weight'] = state_dict['conv_hr.weight']
+ crt_net['model.8.bias'] = state_dict['conv_hr.bias']
+ crt_net['model.10.weight'] = state_dict['conv_last.weight']
+ crt_net['model.10.bias'] = state_dict['conv_last.bias']
+ state_dict = crt_net
+ return state_dict
+
+
+def infer_params(state_dict):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ scale2x = 0
+ scalemin = 6
+ n_uplayer = 0
+ plus = False
+
+ for block in list(state_dict):
+ parts = block.split(".")
+ n_parts = len(parts)
+ if n_parts == 5 and parts[2] == "sub":
+ nb = int(parts[3])
+ elif n_parts == 3:
+ part_num = int(parts[1])
+ if (part_num > scalemin
+ and parts[0] == "model"
+ and parts[2] == "weight"):
+ scale2x += 1
+ if part_num > n_uplayer:
+ n_uplayer = part_num
+ out_nc = state_dict[block].shape[0]
+ if not plus and "conv1x1" in block:
+ plus = True
+
+ nf = state_dict["model.0.weight"].shape[0]
+ in_nc = state_dict["model.0.weight"].shape[1]
+ out_nc = out_nc
+ scale = 2 ** scale2x
+
+ return in_nc, out_nc, nf, nb, plus, scale
- crt_net = crt_model.state_dict()
- load_net_clean = {}
- for k, v in pretrained_net.items():
- if k.startswith('module.'):
- load_net_clean[k[7:]] = v
- else:
- load_net_clean[k] = v
- pretrained_net = load_net_clean
-
- tbd = []
- for k, v in crt_net.items():
- tbd.append(k)
-
- # directly copy
- for k, v in crt_net.items():
- if k in pretrained_net and pretrained_net[k].size() == v.size():
- crt_net[k] = pretrained_net[k]
- tbd.remove(k)
-
- crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
- crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
-
- for k in tbd.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[k] = pretrained_net[ori_k]
- tbd.remove(k)
-
- crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
- crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
- crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
- crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
- crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
- crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
- crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
- crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
- crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
- crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
-
- return crt_net
class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
@@ -109,20 +156,39 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename))
return None
- pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
- crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+ state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
+
+ if "params_ema" in state_dict:
+ state_dict = state_dict["params_ema"]
+ elif "params" in state_dict:
+ state_dict = state_dict["params"]
+ num_conv = 16 if "realesr-animevideov3" in filename else 32
+ model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
+ nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
+ state_dict = resrgan2normal(state_dict, nb)
+ elif "conv_first.weight" in state_dict:
+ state_dict = mod2normal(state_dict)
+ elif "model.0.weight" not in state_dict:
+ raise Exception("The file is not a recognized ESRGAN model.")
+
+ in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
- pretrained_net = fix_model_layers(crt_model, pretrained_net)
- crt_model.load_state_dict(pretrained_net)
- crt_model.eval()
+ model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
+ model.load_state_dict(state_dict)
+ model.eval()
- return crt_model
+ return model
def upscale_without_tiling(model, img):
img = np.array(img)
img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
+ img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad():
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
new file mode 100644
index 00000000..bc9ceb2a
--- /dev/null
+++ b/modules/esrgan_model_arch.py
@@ -0,0 +1,463 @@
+# this file is adapted from https://github.com/victorca25/iNNfer
+
+import math
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+####################
+# RRDBNet Generator
+####################
+
+class RRDBNet(nn.Module):
+ def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
+ act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
+ finalact=None, gaussian_noise=False, plus=False):
+ super(RRDBNet, self).__init__()
+ n_upscale = int(math.log(upscale, 2))
+ if upscale == 3:
+ n_upscale = 1
+
+ self.resrgan_scale = 0
+ if in_nc % 16 == 0:
+ self.resrgan_scale = 1
+ elif in_nc != 4 and in_nc % 4 == 0:
+ self.resrgan_scale = 2
+
+ fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+ rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
+ LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
+
+ if upsample_mode == 'upconv':
+ upsample_block = upconv_block
+ elif upsample_mode == 'pixelshuffle':
+ upsample_block = pixelshuffle_block
+ else:
+ raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+ if upscale == 3:
+ upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
+ else:
+ upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
+ HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
+ HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+
+ outact = act(finalact) if finalact else None
+
+ self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
+ *upsampler, HR_conv0, HR_conv1, outact)
+
+ def forward(self, x, outm=None):
+ if self.resrgan_scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ elif self.resrgan_scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ else:
+ feat = x
+
+ return self.model(feat)
+
+
+class RRDB(nn.Module):
+ """
+ Residual in Residual Dense Block
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+ """
+
+ def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False, gaussian_noise=False, plus=False):
+ super(RRDB, self).__init__()
+ # This is for backwards compatibility with existing models
+ if nr == 3:
+ self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ else:
+ RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
+ self.RDBs = nn.Sequential(*RDB_list)
+
+ def forward(self, x):
+ if hasattr(self, 'RDB1'):
+ out = self.RDB1(x)
+ out = self.RDB2(out)
+ out = self.RDB3(out)
+ else:
+ out = self.RDBs(x)
+ return out * 0.2 + x
+
+
+class ResidualDenseBlock_5C(nn.Module):
+ """
+ Residual Dense Block
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+ Modified options that can be used:
+ - "Partial Convolution based Padding" arXiv:1811.11718
+ - "Spectral normalization" arXiv:1802.05957
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
+ {Rakotonirina} and A. {Rasoanaivo}
+ """
+
+ def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False, gaussian_noise=False, plus=False):
+ super(ResidualDenseBlock_5C, self).__init__()
+
+ self.noise = GaussianNoise() if gaussian_noise else None
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
+
+ self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ if mode == 'CNA':
+ last_act = None
+ else:
+ last_act = act_type
+ self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+
+ def forward(self, x):
+ x1 = self.conv1(x)
+ x2 = self.conv2(torch.cat((x, x1), 1))
+ if self.conv1x1:
+ x2 = x2 + self.conv1x1(x)
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+ if self.conv1x1:
+ x4 = x4 + x2
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ if self.noise:
+ return self.noise(x5.mul(0.2) + x)
+ else:
+ return x5 * 0.2 + x
+
+
+####################
+# ESRGANplus
+####################
+
+class GaussianNoise(nn.Module):
+ def __init__(self, sigma=0.1, is_relative_detach=False):
+ super().__init__()
+ self.sigma = sigma
+ self.is_relative_detach = is_relative_detach
+ self.noise = torch.tensor(0, dtype=torch.float)
+
+ def forward(self, x):
+ if self.training and self.sigma != 0:
+ self.noise = self.noise.to(x.device)
+ scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+ sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+ x = x + sampled_noise
+ return x
+
+def conv1x1(in_planes, out_planes, stride=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+####################
+# SRVGGNetCompact
+####################
+
+class SRVGGNetCompact(nn.Module):
+ """A compact VGG-style network structure for super-resolution.
+ This class is copied from https://github.com/xinntao/Real-ESRGAN
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+ super(SRVGGNetCompact, self).__init__()
+ self.num_in_ch = num_in_ch
+ self.num_out_ch = num_out_ch
+ self.num_feat = num_feat
+ self.num_conv = num_conv
+ self.upscale = upscale
+ self.act_type = act_type
+
+ self.body = nn.ModuleList()
+ # the first conv
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+ # the first activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the body structure
+ for _ in range(num_conv):
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+ # activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the last conv
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+ # upsample
+ self.upsampler = nn.PixelShuffle(upscale)
+
+ def forward(self, x):
+ out = x
+ for i in range(0, len(self.body)):
+ out = self.body[i](out)
+
+ out = self.upsampler(out)
+ # add the nearest upsampled image, so that the network learns the residual
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+ out += base
+ return out
+
+
+####################
+# Upsampler
+####################
+
+class Upsample(nn.Module):
+ r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
+ The input data is assumed to be of the form
+ `minibatch x channels x [optional depth] x [optional height] x width`.
+ """
+
+ def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
+ super(Upsample, self).__init__()
+ if isinstance(scale_factor, tuple):
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
+ else:
+ self.scale_factor = float(scale_factor) if scale_factor else None
+ self.mode = mode
+ self.size = size
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+
+ def extra_repr(self):
+ if self.scale_factor is not None:
+ info = 'scale_factor=' + str(self.scale_factor)
+ else:
+ info = 'size=' + str(self.size)
+ info += ', mode=' + self.mode
+ return info
+
+
+def pixel_unshuffle(x, scale):
+ """ Pixel unshuffle.
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+ Returns:
+ Tensor: the pixel unshuffled feature.
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+ assert hh % scale == 0 and hw % scale == 0
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
+ """
+ Pixel shuffle layer
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+ Neural Network, CVPR17)
+ """
+ conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
+ pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+ n = norm(norm_type, out_nc) if norm_type else None
+ a = act(act_type) if act_type else None
+ return sequential(conv, pixel_shuffle, n, a)
+
+
+def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
+ """ Upconv layer """
+ upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
+ upsample = Upsample(scale_factor=upscale_factor, mode=mode)
+ conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
+ pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
+ return sequential(upsample, conv)
+
+
+
+
+
+
+
+
+####################
+# Basic blocks
+####################
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+ Args:
+ basic_block (nn.module): nn.module class for basic block. (block)
+ num_basic_block (int): number of blocks. (n_layers)
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
+
+def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
+ """ activation helper """
+ act_type = act_type.lower()
+ if act_type == 'relu':
+ layer = nn.ReLU(inplace)
+ elif act_type in ('leakyrelu', 'lrelu'):
+ layer = nn.LeakyReLU(neg_slope, inplace)
+ elif act_type == 'prelu':
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+ elif act_type == 'tanh': # [-1, 1] range output
+ layer = nn.Tanh()
+ elif act_type == 'sigmoid': # [0, 1] range output
+ layer = nn.Sigmoid()
+ else:
+ raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
+ return layer
+
+
+class Identity(nn.Module):
+ def __init__(self, *kwargs):
+ super(Identity, self).__init__()
+
+ def forward(self, x, *kwargs):
+ return x
+
+
+def norm(norm_type, nc):
+ """ Return a normalization layer """
+ norm_type = norm_type.lower()
+ if norm_type == 'batch':
+ layer = nn.BatchNorm2d(nc, affine=True)
+ elif norm_type == 'instance':
+ layer = nn.InstanceNorm2d(nc, affine=False)
+ elif norm_type == 'none':
+ def norm_layer(x): return Identity()
+ else:
+ raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
+ return layer
+
+
+def pad(pad_type, padding):
+ """ padding layer helper """
+ pad_type = pad_type.lower()
+ if padding == 0:
+ return None
+ if pad_type == 'reflect':
+ layer = nn.ReflectionPad2d(padding)
+ elif pad_type == 'replicate':
+ layer = nn.ReplicationPad2d(padding)
+ elif pad_type == 'zero':
+ layer = nn.ZeroPad2d(padding)
+ else:
+ raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+ return layer
+
+
+def get_valid_padding(kernel_size, dilation):
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+ padding = (kernel_size - 1) // 2
+ return padding
+
+
+class ShortcutBlock(nn.Module):
+ """ Elementwise sum the output of a submodule to its input """
+ def __init__(self, submodule):
+ super(ShortcutBlock, self).__init__()
+ self.sub = submodule
+
+ def forward(self, x):
+ output = x + self.sub(x)
+ return output
+
+ def __repr__(self):
+ return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
+
+
+def sequential(*args):
+ """ Flatten Sequential. It unwraps nn.Sequential. """
+ if len(args) == 1:
+ if isinstance(args[0], OrderedDict):
+ raise NotImplementedError('sequential does not support OrderedDict input.')
+ return args[0] # No sequential is needed.
+ modules = []
+ for module in args:
+ if isinstance(module, nn.Sequential):
+ for submodule in module.children():
+ modules.append(submodule)
+ elif isinstance(module, nn.Module):
+ modules.append(module)
+ return nn.Sequential(*modules)
+
+
+def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False):
+ """ Conv layer with padding, normalization, activation """
+ assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
+ padding = get_valid_padding(kernel_size, dilation)
+ p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
+ padding = padding if pad_type == 'zero' else 0
+
+ if convtype=='PartialConv2D':
+ c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ elif convtype=='DeformConv2D':
+ c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ elif convtype=='Conv3D':
+ c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ else:
+ c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+
+ if spectral_norm:
+ c = nn.utils.spectral_norm(c)
+
+ a = act(act_type) if act_type else None
+ if 'CNA' in mode:
+ n = norm(norm_type, out_nc) if norm_type else None
+ return sequential(p, c, n, a)
+ elif mode == 'NAC':
+ if norm_type is None and act_type is not None:
+ a = act(act_type, inplace=False)
+ n = norm(norm_type, in_nc) if norm_type else None
+ return sequential(n, a, p, c)
--
cgit v1.2.3
From bb57f30c2de46cfca5419ad01738a41705f96cc3 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Fri, 14 Oct 2022 10:56:41 +0200
Subject: init
---
modules/processing.py | 17 +++++-
modules/sd_hijack.py | 80 +++++++++++++++++++++++++-
modules/shared.py | 5 ++
modules/textual_inversion/dataset.py | 2 +-
modules/textual_inversion/textual_inversion.py | 35 +++++++----
modules/txt2img.py | 11 +++-
modules/ui.py | 59 ++++++++++++-------
7 files changed, 171 insertions(+), 38 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index d5172f00..9a033759 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -316,11 +316,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
-def process_images(p: StableDiffusionProcessing) -> Processed:
+def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0,
+ aesthetic_imgs=None,aesthetic_slerp=False) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
+ aesthetic_lr = float(aesthetic_lr)
+ aesthetic_weight = float(aesthetic_weight)
+ aesthetic_steps = int(aesthetic_steps)
+
if type(p.prompt) == list:
- assert(len(p.prompt) > 0)
+ assert (len(p.prompt) > 0)
else:
assert p.prompt is not None
@@ -394,7 +399,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
#c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
- uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
+ if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
+ shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0)
+ uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt],
+ p.steps)
+ if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
+ shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight,
+ aesthetic_steps, aesthetic_imgs,aesthetic_slerp)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0:
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index c81722a0..6d5196fe 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -9,11 +9,14 @@ from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
-from modules.shared import opts, device, cmd_opts
+from modules.shared import opts, device, cmd_opts, aesthetic_embeddings
from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+from transformers import CLIPVisionModel, CLIPModel
+import torch.optim as optim
+import copy
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
@@ -109,13 +112,29 @@ class StableDiffusionModelHijack:
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
+def slerp(low, high, val):
+ low_norm = low/torch.norm(low, dim=1, keepdim=True)
+ high_norm = high/torch.norm(high, dim=1, keepdim=True)
+ omega = torch.acos((low_norm*high_norm).sum(1))
+ so = torch.sin(omega)
+ res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+ return res
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
+ self.clipModel = CLIPModel.from_pretrained(
+ self.wrapped.transformer.name_or_path
+ )
+ del self.clipModel.vision_model
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
+ # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval()
+ self.image_embs_name = None
+ self.image_embs = None
+ self.load_image_embs(None)
+
self.token_mults = {}
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0]
@@ -136,6 +155,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
+ def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None,
+ aesthetic_slerp=True):
+ self.slerp = aesthetic_slerp
+ self.aesthetic_lr = aesthetic_lr
+ self.aesthetic_weight = aesthetic_weight
+ self.aesthetic_steps = aesthetic_steps
+ self.load_image_embs(image_embs_name)
+
+ def load_image_embs(self, image_embs_name):
+ if image_embs_name is None or len(image_embs_name) == 0:
+ image_embs_name = None
+ if image_embs_name is not None and self.image_embs_name != image_embs_name:
+ self.image_embs_name = image_embs_name
+ self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device)
+ self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
+ self.image_embs.requires_grad_(False)
+
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_end = self.wrapped.tokenizer.eos_token_id
@@ -333,7 +369,47 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
+
+ if len(text[
+ 0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None:
+ if not opts.use_old_emphasis_implementation:
+ remade_batch_tokens = [
+ [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
+ remade_batch_tokens]
+
+ tokens = torch.asarray(remade_batch_tokens).to(device)
+ with torch.enable_grad():
+ model = copy.deepcopy(self.clipModel).to(device)
+ model.requires_grad_(True)
+
+ # We optimize the model to maximize the similarity
+ optimizer = optim.Adam(
+ model.text_model.parameters(), lr=self.aesthetic_lr
+ )
+
+ for i in range(self.aesthetic_steps):
+ text_embs = model.get_text_features(input_ids=tokens)
+ text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
+ sim = text_embs @ self.image_embs.T
+ loss = -sim
+ optimizer.zero_grad()
+ loss.mean().backward()
+ optimizer.step()
+
+ zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+ if opts.CLIP_stop_at_last_layers > 1:
+ zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
+ zn = model.text_model.final_layer_norm(zn)
+ else:
+ zn = zn.last_hidden_state
+ model.cpu()
+ del model
+
+ if self.slerp:
+ z = slerp(z, zn, self.aesthetic_weight)
+ else:
+ z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
+
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
diff --git a/modules/shared.py b/modules/shared.py
index 5901e605..cf13a10d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -30,6 +30,8 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(script_path, 'aesthetic_embeddings'),
+ help="aesthetic_embeddings directory(default: aesthetic_embeddings)")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
@@ -90,6 +92,9 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
+aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
+ os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
+
def reload_hypernetworks():
global hypernetworks
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 67e90afe..59b2b021 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -48,7 +48,7 @@ class PersonalizedBase(Dataset):
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
try:
- image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC)
except Exception:
continue
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fa0e33a2..b12a8e6d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -172,7 +172,15 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
+def batched(dataset, total, n=1):
+ for ndx in range(0, total, n):
+ yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
+
+
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps,
+ create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding,
+ preview_image_prompt, batch_size=1,
+ gradient_accumulation=1):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -204,7 +212,11 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
+ height=training_height,
+ repeats=shared.opts.training_image_repeats_per_epoch,
+ placeholder_token=embedding_name, model=shared.sd_model,
+ device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -223,7 +235,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
- pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+ pbar = tqdm.tqdm(enumerate(batched(ds, steps - ititial_step, batch_size)), total=steps - ititial_step)
for i, entry in pbar:
embedding.step = i + ititial_step
@@ -235,17 +247,20 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
break
with torch.autocast("cuda"):
- c = cond_model([entry.cond_text])
+ c = cond_model([e.cond_text for e in entry])
+
+ x = torch.stack([e.latent for e in entry]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
- x = entry.latent.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
- optimizer.zero_grad()
loss.backward()
- optimizer.step()
+ if ((i + 1) % gradient_accumulation == 0) or (i + 1 == steps - ititial_step):
+ optimizer.step()
+ optimizer.zero_grad()
+
epoch_num = embedding.step // len(ds)
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
@@ -259,7 +274,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
- preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
+ preview_text = entry[0].cond_text if preview_image_prompt == "" else preview_image_prompt
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
@@ -305,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entry[-1].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
diff --git a/modules/txt2img.py b/modules/txt2img.py
index e985242b..78342024 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -6,7 +6,14 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int,
+ restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int,
+ subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool,
+ height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float,
+ aesthetic_lr=0,
+ aesthetic_weight=0, aesthetic_steps=0,
+ aesthetic_imgs=None,
+ aesthetic_slerp=False, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -40,7 +47,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
- processed = process_images(p)
+ processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp)
shared.total_tqdm.clear()
diff --git a/modules/ui.py b/modules/ui.py
index 220fb80b..d961d126 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -24,7 +24,8 @@ import gradio.routes
from modules import sd_hijack
from modules.paths import script_path
-from modules.shared import opts, cmd_opts
+from modules.shared import opts, cmd_opts,aesthetic_embeddings
+
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
import modules.shared as shared
@@ -534,6 +535,14 @@ def create_ui(wrap_gradio_gpu_call):
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ with gr.Group():
+ aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+ aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7)
+ aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50)
+
+ aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
+ aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
+
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
tiling = gr.Checkbox(label='Tiling', value=False)
@@ -586,25 +595,30 @@ def create_ui(wrap_gradio_gpu_call):
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
_js="submit",
inputs=[
- txt2img_prompt,
- txt2img_negative_prompt,
- txt2img_prompt_style,
- txt2img_prompt_style2,
- steps,
- sampler_index,
- restore_faces,
- tiling,
- batch_count,
- batch_size,
- cfg_scale,
- seed,
- subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
- height,
- width,
- enable_hr,
- scale_latent,
- denoising_strength,
- ] + custom_inputs,
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ txt2img_prompt_style,
+ txt2img_prompt_style2,
+ steps,
+ sampler_index,
+ restore_faces,
+ tiling,
+ batch_count,
+ batch_size,
+ cfg_scale,
+ seed,
+ subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
+ height,
+ width,
+ enable_hr,
+ scale_latent,
+ denoising_strength,
+ aesthetic_lr,
+ aesthetic_weight,
+ aesthetic_steps,
+ aesthetic_imgs,
+ aesthetic_slerp
+ ] + custom_inputs,
outputs=[
txt2img_gallery,
generation_info,
@@ -1097,6 +1111,9 @@ def create_ui(wrap_gradio_gpu_call):
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ batch_size = gr.Slider(minimum=1, maximum=64, step=1, label="Batch Size", value=4)
+ gradient_accumulation = gr.Slider(minimum=1, maximum=256, step=1, label="Gradient accumulation",
+ value=1)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
@@ -1180,6 +1197,8 @@ def create_ui(wrap_gradio_gpu_call):
template_file,
save_image_with_stored_embedding,
preview_image_prompt,
+ batch_size,
+ gradient_accumulation
],
outputs=[
ti_output,
--
cgit v1.2.3
From 37d7ffb415cd8c69b3c0bb5f61844dde0b169f78 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sat, 15 Oct 2022 15:59:37 +0200
Subject: fix to tokens lenght, addend embs generator, add new features to edit
the embedding before the generation using text
---
modules/aesthetic_clip.py | 78 ++++++++++++++++++++++++
modules/processing.py | 148 +++++++++++++++++++++++++++++++---------------
modules/sd_hijack.py | 111 ++++++++++++++++++++++------------
modules/shared.py | 4 ++
modules/txt2img.py | 10 +++-
modules/ui.py | 47 ++++++++++++---
6 files changed, 302 insertions(+), 96 deletions(-)
create mode 100644 modules/aesthetic_clip.py
(limited to 'modules')
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
new file mode 100644
index 00000000..f15cfd47
--- /dev/null
+++ b/modules/aesthetic_clip.py
@@ -0,0 +1,78 @@
+import itertools
+import os
+from pathlib import Path
+import html
+import gc
+
+import gradio as gr
+import torch
+from PIL import Image
+from modules import shared
+from modules.shared import device, aesthetic_embeddings
+from transformers import CLIPModel, CLIPProcessor
+
+from tqdm.auto import tqdm
+
+
+def get_all_images_in_folder(folder):
+ return [os.path.join(folder, f) for f in os.listdir(folder) if
+ os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
+
+
+def check_is_valid_image_file(filename):
+ return filename.lower().endswith(('.png', '.jpg', '.jpeg'))
+
+
+def batched(dataset, total, n=1):
+ for ndx in range(0, total, n):
+ yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
+
+
+def iter_to_batched(iterable, n=1):
+ it = iter(iterable)
+ while True:
+ chunk = tuple(itertools.islice(it, n))
+ if not chunk:
+ return
+ yield chunk
+
+
+def generate_imgs_embd(name, folder, batch_size):
+ # clipModel = CLIPModel.from_pretrained(
+ # shared.sd_model.cond_stage_model.clipModel.name_or_path
+ # )
+ model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device)
+ processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path)
+
+ with torch.no_grad():
+ embs = []
+ for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
+ desc=f"Generating embeddings for {name}"):
+ if shared.state.interrupted:
+ break
+ inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
+ outputs = model.get_image_features(**inputs).cpu()
+ embs.append(torch.clone(outputs))
+ inputs.to("cpu")
+ del inputs, outputs
+
+ embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
+
+ # The generated embedding will be located here
+ path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
+ torch.save(embs, path)
+
+ model = model.cpu()
+ del model
+ del processor
+ del embs
+ gc.collect()
+ torch.cuda.empty_cache()
+ res = f"""
+ Done generating embedding for {name}!
+ Hypernetwork saved to {html.escape(path)}
+ """
+ shared.update_aesthetic_embeddings()
+ return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding",
+ value=sorted(aesthetic_embeddings.keys())[0] if len(
+ aesthetic_embeddings) > 0 else None), res, ""
diff --git a/modules/processing.py b/modules/processing.py
index 9a033759..ab68d63a 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -20,7 +20,6 @@ import modules.images as images
import modules.styles
import logging
-
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
@@ -52,8 +51,13 @@ def get_correct_sampler(p):
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img
+
class StableDiffusionProcessing:
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1,
+ subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True,
+ sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512,
+ restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False,
+ extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -104,7 +108,8 @@ class StableDiffusionProcessing:
class Processed:
- def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None,
+ all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
self.images = images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
@@ -141,7 +146,8 @@ class Processed:
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
- self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
+ self.subseed = int(
+ self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed]
@@ -181,39 +187,43 @@ class Processed:
return json.dumps(obj)
- def infotext(self, p: StableDiffusionProcessing, index):
- return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
+ def infotext(self, p: StableDiffusionProcessing, index):
+ return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[],
+ position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
- low_norm = low/torch.norm(low, dim=1, keepdim=True)
- high_norm = high/torch.norm(high, dim=1, keepdim=True)
- dot = (low_norm*high_norm).sum(1)
+ low_norm = low / torch.norm(low, dim=1, keepdim=True)
+ high_norm = high / torch.norm(high, dim=1, keepdim=True)
+ dot = (low_norm * high_norm).sum(1)
if dot.mean() > 0.9995:
return low * val + high * (1 - val)
omega = torch.acos(dot)
so = torch.sin(omega)
- res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+ res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
-def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
+def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0,
+ p=None):
xs = []
# if we have multiple seeds, this means we are working with batch size>1; this then
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
+ if p is not None and p.sampler is not None and (
+ len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
for i, seed in enumerate(seeds):
- noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
+ noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (
+ shape[0], seed_resize_from_h // 8, seed_resize_from_w // 8)
subnoise = None
if subseeds is not None:
@@ -241,7 +251,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
dx = max(-dx, 0)
dy = max(-dy, 0)
- x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
+ x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
noise = x
if sampler_noises is not None:
@@ -293,14 +303,20 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
- "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
- "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
+ "Model hash": getattr(p, 'sd_model_hash',
+ None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
+ "Model": (
+ None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(
+ ',', '').replace(':', '')),
+ "Hypernet": (
+ None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(
+ ':', '')),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
- "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
+ "Seed resize from": (
+ None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip,
@@ -309,7 +325,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params.update(p.extra_generation_params)
- generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+ generation_params_text = ", ".join(
+ [k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
@@ -317,7 +334,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0,
- aesthetic_imgs=None,aesthetic_slerp=False) -> Processed:
+ aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="",
+ aesthetic_slerp_angle=0.15,
+ aesthetic_text_negative=False) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
aesthetic_lr = float(aesthetic_lr)
@@ -385,7 +404,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
for n in range(p.n_iter):
if state.skipped:
state.skipped = False
-
+
if state.interrupted:
break
@@ -396,16 +415,19 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
if (len(prompts) == 0):
break
- #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
- #c = p.sd_model.get_learned_conditioning(prompts)
+ # uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
+ # c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
- shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0)
+ shared.sd_model.cond_stage_model.set_aesthetic_params()
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt],
p.steps)
if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight,
- aesthetic_steps, aesthetic_imgs,aesthetic_slerp)
+ aesthetic_steps, aesthetic_imgs,
+ aesthetic_slerp, aesthetic_imgs_text,
+ aesthetic_slerp_angle,
+ aesthetic_text_negative)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0:
@@ -413,13 +435,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
comments[comment] = 1
if p.n_iter > 1:
- shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+ shared.state.job = f"Batch {n + 1} out of {p.n_iter}"
with devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds,
+ subseed_strength=p.subseed_strength)
if state.interrupted or state.skipped:
-
# if we are interrupted, sample returns just noise
# use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent
@@ -445,7 +467,9 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
- images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i],
+ opts.samples_format, info=infotext(n, i), p=p,
+ suffix="-before-face-restoration")
devices.torch_gc()
@@ -456,7 +480,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+ images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format,
+ info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
if p.overlay_images is not None and i < len(p.overlay_images):
@@ -474,7 +499,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
image = image.convert('RGB')
if opts.samples_save and not p.do_not_save_samples:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
+ images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format,
+ info=infotext(n, i), p=p)
text = infotext(n, i)
infotexts.append(text)
@@ -482,7 +508,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
image.info["parameters"] = text
output_images.append(image)
- del x_samples_ddim
+ del x_samples_ddim
devices.torch_gc()
@@ -504,10 +530,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format,
+ info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+ return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]),
+ subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds,
+ index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
@@ -543,25 +572,34 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds,
+ subseeds=subseeds, subseed_strength=self.subseed_strength,
+ seed_resize_from_h=self.seed_resize_from_h,
+ seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
return samples
- x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds,
+ subseeds=subseeds, subseed_strength=self.subseed_strength,
+ seed_resize_from_h=self.seed_resize_from_h,
+ seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
- samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
+ samples = samples[:, :, truncate_y // 2:samples.shape[2] - truncate_y // 2,
+ truncate_x // 2:samples.shape[3] - truncate_x // 2]
if self.scale_latent:
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f),
+ mode="bilinear")
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
- decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
+ decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width),
+ mode="bilinear")
else:
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -585,13 +623,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
- noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds,
+ subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h,
+ seed_resize_from_w=self.seed_resize_from_w, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning,
+ steps=self.steps)
return samples
@@ -599,7 +640,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
+ def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4,
+ inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0,
+ **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
@@ -607,7 +650,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.denoising_strength: float = denoising_strength
self.init_latent = None
self.image_mask = mask
- #self.image_unblurred_mask = None
+ # self.image_unblurred_mask = None
self.latent_mask = None
self.mask_for_overlay = None
self.mask_blur = mask_blur
@@ -619,7 +662,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.nmask = None
def init(self, all_prompts, all_seeds, all_subseeds):
- self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index,
+ self.sd_model)
crop_region = None
if self.image_mask is not None:
@@ -628,7 +672,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_mask_invert:
self.image_mask = ImageOps.invert(self.image_mask)
- #self.image_unblurred_mask = self.image_mask
+ # self.image_unblurred_mask = self.image_mask
if self.mask_blur > 0:
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
@@ -642,7 +686,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask = mask.crop(crop_region)
self.image_mask = images.resize_image(2, mask, self.width, self.height)
- self.paste_to = (x1, y1, x2-x1, y2-y1)
+ self.paste_to = (x1, y1, x2 - x1, y2 - y1)
else:
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
np_mask = np.array(self.image_mask)
@@ -665,7 +709,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.image_mask is not None:
image_masked = Image.new('RGBa', (image.width, image.height))
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+ image_masked.paste(image.convert("RGBA").convert("RGBa"),
+ mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
self.overlay_images.append(image_masked.convert('RGBA'))
@@ -714,12 +759,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
# this needs to be fixed to be done in sample() using actual seeds for batches
if self.inpainting_fill == 2:
- self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
+ self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:],
+ all_seeds[
+ 0:self.init_latent.shape[
+ 0]]) * self.nmask
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds,
+ subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h,
+ seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 6d5196fe..192883b2 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -14,7 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
-from transformers import CLIPVisionModel, CLIPModel
+from tqdm import trange
+from transformers import CLIPVisionModel, CLIPModel, CLIPTokenizer
import torch.optim as optim
import copy
@@ -22,21 +23,25 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
def apply_optimizations():
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
+ if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (
+ 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
+ elif not cmd_opts.disable_opt_split_attention and (
+ cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
if not invokeAI_mps_available and shared.device.type == 'mps':
- print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
+ print(
+ "The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
else:
@@ -112,14 +117,16 @@ class StableDiffusionModelHijack:
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
+
def slerp(low, high, val):
- low_norm = low/torch.norm(low, dim=1, keepdim=True)
- high_norm = high/torch.norm(high, dim=1, keepdim=True)
- omega = torch.acos((low_norm*high_norm).sum(1))
+ low_norm = low / torch.norm(low, dim=1, keepdim=True)
+ high_norm = high / torch.norm(high, dim=1, keepdim=True)
+ omega = torch.acos((low_norm * high_norm).sum(1))
so = torch.sin(omega)
- res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+ res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
+
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
@@ -128,6 +135,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.wrapped.transformer.name_or_path
)
del self.clipModel.vision_model
+ self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path)
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
# self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval()
@@ -139,7 +147,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0]
- tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+ tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if
+ '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
@@ -155,8 +164,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
- def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None,
- aesthetic_slerp=True):
+ def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
+ aesthetic_slerp=True, aesthetic_imgs_text="",
+ aesthetic_slerp_angle=0.15,
+ aesthetic_text_negative=False):
+ self.aesthetic_imgs_text = aesthetic_imgs_text
+ self.aesthetic_slerp_angle = aesthetic_slerp_angle
+ self.aesthetic_text_negative = aesthetic_text_negative
self.slerp = aesthetic_slerp
self.aesthetic_lr = aesthetic_lr
self.aesthetic_weight = aesthetic_weight
@@ -180,7 +194,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
else:
parsed = [[line, 1.0]]
- tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
+ tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)[
+ "input_ids"]
fixes = []
remade_tokens = []
@@ -196,18 +211,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if token == self.comma_token:
last_comma = len(remade_tokens)
- elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
+ elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens),
+ 1) % 75 == 0 and last_comma != -1 and len(
+ remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1
reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:]
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
-
+
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
-
+
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
@@ -248,7 +265,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
- remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms,
+ hijack_comments)
token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
@@ -259,7 +277,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
@@ -289,7 +306,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens,
+ i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
@@ -312,11 +330,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+ hijack_comments.append(
+ f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
@@ -326,23 +345,26 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
+
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(
+ text)
else:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(
+ text)
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
+ self.hijack.comments.append(
+ "Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
+
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
@@ -356,7 +378,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
-
+
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
@@ -378,19 +400,30 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_batch_tokens]
tokens = torch.asarray(remade_batch_tokens).to(device)
+
+ model = copy.deepcopy(self.clipModel).to(device)
+ model.requires_grad_(True)
+ if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
+ text_embs_2 = model.get_text_features(
+ **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
+ if self.aesthetic_text_negative:
+ text_embs_2 = self.image_embs - text_embs_2
+ text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
+ img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
+ else:
+ img_embs = self.image_embs
+
with torch.enable_grad():
- model = copy.deepcopy(self.clipModel).to(device)
- model.requires_grad_(True)
# We optimize the model to maximize the similarity
optimizer = optim.Adam(
model.text_model.parameters(), lr=self.aesthetic_lr
)
- for i in range(self.aesthetic_steps):
+ for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
text_embs = model.get_text_features(input_ids=tokens)
text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
- sim = text_embs @ self.image_embs.T
+ sim = text_embs @ img_embs.T
loss = -sim
optimizer.zero_grad()
loss.mean().backward()
@@ -405,6 +438,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
model.cpu()
del model
+ zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1)
if self.slerp:
z = slerp(z, zn, self.aesthetic_weight)
else:
@@ -413,15 +447,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
-
+
return z
-
-
+
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
+ remade_batch_tokens = [
+ [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
+ remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
-
+
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
@@ -461,8 +496,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
emb = embedding.vec
- emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
- tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+ emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+ tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
vecs.append(tensor)
diff --git a/modules/shared.py b/modules/shared.py
index cf13a10d..7cd608ca 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -95,6 +95,10 @@ loaded_hypernetwork = None
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
+def update_aesthetic_embeddings():
+ global aesthetic_embeddings
+ aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
+ os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
def reload_hypernetworks():
global hypernetworks
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 78342024..eedcdfe0 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -13,7 +13,11 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
aesthetic_lr=0,
aesthetic_weight=0, aesthetic_steps=0,
aesthetic_imgs=None,
- aesthetic_slerp=False, *args):
+ aesthetic_slerp=False,
+ aesthetic_imgs_text="",
+ aesthetic_slerp_angle=0.15,
+ aesthetic_text_negative=False,
+ *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -47,7 +51,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
- processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp)
+ processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,aesthetic_imgs_text,
+ aesthetic_slerp_angle,
+ aesthetic_text_negative)
shared.total_tqdm.clear()
diff --git a/modules/ui.py b/modules/ui.py
index d961d126..e98e2113 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -41,6 +41,7 @@ from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui
+import modules.aesthetic_clip
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -449,7 +450,7 @@ def create_toprow(is_img2img):
with gr.Row():
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
with gr.Column(scale=1, elem_id="roll_col"):
- sh = gr.Button(elem_id="sh", visible=True)
+ sh = gr.Button(elem_id="sh", visible=True)
with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
@@ -536,9 +537,13 @@ def create_ui(wrap_gradio_gpu_call):
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Group():
- aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
- aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7)
- aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50)
+ aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001")
+ aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
+ aesthetic_steps = gr.Slider(minimum=0, maximum=256, step=1, label="Aesthetic steps", value=5)
+ with gr.Row():
+ aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
+ aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
+ aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
@@ -617,7 +622,10 @@ def create_ui(wrap_gradio_gpu_call):
aesthetic_weight,
aesthetic_steps,
aesthetic_imgs,
- aesthetic_slerp
+ aesthetic_slerp,
+ aesthetic_imgs_text,
+ aesthetic_slerp_angle,
+ aesthetic_text_negative
] + custom_inputs,
outputs=[
txt2img_gallery,
@@ -721,7 +729,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
- inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32)
+ inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=1024, step=4, value=32)
with gr.TabItem('Batch img2img', id='batch'):
hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
@@ -1071,6 +1079,17 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary')
+ with gr.Tab(label="Create images embedding"):
+ new_embedding_name_ae = gr.Textbox(label="Name")
+ process_src_ae = gr.Textbox(label='Source directory')
+ batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_embedding_ae = gr.Button(value="Create images embedding", variant='primary')
+
with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
@@ -1139,7 +1158,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=modules.textual_inversion.ui.create_embedding,
inputs=[
new_embedding_name,
- initialization_text,
+ process_src,
nvpt,
],
outputs=[
@@ -1149,6 +1168,20 @@ def create_ui(wrap_gradio_gpu_call):
]
)
+ create_embedding_ae.click(
+ fn=modules.aesthetic_clip.generate_imgs_embd,
+ inputs=[
+ new_embedding_name_ae,
+ process_src_ae,
+ batch_ae
+ ],
+ outputs=[
+ aesthetic_imgs,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
create_hypernetwork.click(
fn=modules.hypernetworks.ui.create_hypernetwork,
inputs=[
--
cgit v1.2.3
From 4387e4fe6479c08f7bc7e42924c3a1093e3a1872 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sat, 15 Oct 2022 18:39:29 +0200
Subject: Update modules/ui.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: VÃctor Gallego
---
modules/ui.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index d0696101..5bb961b2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -599,7 +599,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001")
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
- aesthetic_steps = gr.Slider(minimum=0, maximum=256, step=1, label="Aesthetic steps", value=5)
+ aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
+
with gr.Row():
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
--
cgit v1.2.3
From 9b7705e0573bddde26df4575c71f994d73a4d519 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sat, 15 Oct 2022 18:40:34 +0200
Subject: Update modules/aesthetic_clip.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: VÃctor Gallego
---
modules/aesthetic_clip.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
index f15cfd47..bcf2b073 100644
--- a/modules/aesthetic_clip.py
+++ b/modules/aesthetic_clip.py
@@ -70,7 +70,7 @@ def generate_imgs_embd(name, folder, batch_size):
torch.cuda.empty_cache()
res = f"""
Done generating embedding for {name}!
- Hypernetwork saved to {html.escape(path)}
+ Aesthetic embedding saved to {html.escape(path)}
"""
shared.update_aesthetic_embeddings()
return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding",
--
cgit v1.2.3
From 0d4f5db235357aeb4c7a8738179ba33aaf5a6b75 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sat, 15 Oct 2022 18:40:58 +0200
Subject: Update modules/ui.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: VÃctor Gallego
---
modules/ui.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 5bb961b2..25eba548 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -597,7 +597,8 @@ def create_ui(wrap_gradio_gpu_call):
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Group():
- aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001")
+ aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001")
+
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
--
cgit v1.2.3
From ad9bc604a8fadcfebe72be37f66cec51e7e87fb5 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sat, 15 Oct 2022 18:41:18 +0200
Subject: Update modules/ui.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: VÃctor Gallego
---
modules/ui.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 25eba548..3b28b69c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -607,7 +607,8 @@ def create_ui(wrap_gradio_gpu_call):
aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
- aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
+ aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Aesthetic imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
+
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
with gr.Row():
--
cgit v1.2.3
From 3f5c3b981e46c16bb10948d012575b25170efb3b Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sat, 15 Oct 2022 18:41:46 +0200
Subject: Update modules/ui.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: VÃctor Gallego
---
modules/ui.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 3b28b69c..1f6fcdc9 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1190,7 +1190,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary')
- with gr.Tab(label="Create images embedding"):
+ with gr.Tab(label="Create aesthetic images embedding"):
+
new_embedding_name_ae = gr.Textbox(label="Name")
process_src_ae = gr.Textbox(label='Source directory')
batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
--
cgit v1.2.3
From 9a33292ce41b01252cdb8ab6214a11d274e32fa0 Mon Sep 17 00:00:00 2001
From: zhengxiaoyao0716 <1499383852@qq.com>
Date: Sat, 15 Oct 2022 01:04:47 +0800
Subject: reload javascript files when custom script bodies
---
modules/ui.py | 28 ++++++++++++++++------------
1 file changed, 16 insertions(+), 12 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index b867d40f..90b8646b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -12,7 +12,7 @@ import time
import traceback
import platform
import subprocess as sp
-from functools import reduce
+from functools import partial, reduce
import numpy as np
import torch
@@ -1491,6 +1491,7 @@ Requested path was: {f}
def reload_scripts():
modules.scripts.reload_script_body_only()
+ reload_javascript() # need to refresh the html page
reload_script_bodies.click(
fn=reload_scripts,
@@ -1738,22 +1739,25 @@ Requested path was: {f}
return demo
-with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
- javascript = f''
+def load_javascript(raw_response):
+ with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
+ javascript = f''
-jsdir = os.path.join(script_path, "javascript")
-for filename in sorted(os.listdir(jsdir)):
- with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
- javascript += f"\n"
+ jsdir = os.path.join(script_path, "javascript")
+ for filename in sorted(os.listdir(jsdir)):
+ with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
+ javascript += f"\n"
-
-if 'gradio_routes_templates_response' not in globals():
def template_response(*args, **kwargs):
- res = gradio_routes_templates_response(*args, **kwargs)
- res.body = res.body.replace(b'', f'{javascript}'.encode("utf8"))
+ res = raw_response(*args, **kwargs)
+ res.body = res.body.replace(
+ b'', f'{javascript}'.encode("utf8"))
res.init_headers()
return res
- gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
gradio.routes.templates.TemplateResponse = template_response
+
+reload_javascript = partial(load_javascript,
+ gradio.routes.templates.TemplateResponse)
+reload_javascript()
--
cgit v1.2.3
From 3d21684ee30ca5734126b8d08c05b3a0f513fe75 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sun, 16 Oct 2022 00:01:00 +0200
Subject: Add support to other img format, fixed dropbox update
---
modules/aesthetic_clip.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
index bcf2b073..68264284 100644
--- a/modules/aesthetic_clip.py
+++ b/modules/aesthetic_clip.py
@@ -8,7 +8,7 @@ import gradio as gr
import torch
from PIL import Image
from modules import shared
-from modules.shared import device, aesthetic_embeddings
+from modules.shared import device
from transformers import CLIPModel, CLIPProcessor
from tqdm.auto import tqdm
@@ -20,7 +20,7 @@ def get_all_images_in_folder(folder):
def check_is_valid_image_file(filename):
- return filename.lower().endswith(('.png', '.jpg', '.jpeg'))
+ return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp"))
def batched(dataset, total, n=1):
@@ -73,6 +73,6 @@ def generate_imgs_embd(name, folder, batch_size):
Aesthetic embedding saved to {html.escape(path)}
"""
shared.update_aesthetic_embeddings()
- return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding",
- value=sorted(aesthetic_embeddings.keys())[0] if len(
- aesthetic_embeddings) > 0 else None), res, ""
+ return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
+ value=sorted(shared.aesthetic_embeddings.keys())[0] if len(
+ shared.aesthetic_embeddings) > 0 else None), res, ""
--
cgit v1.2.3
From 9325c85f780c569d1823e422eaf51b2e497e0d3e Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sun, 16 Oct 2022 00:23:47 +0200
Subject: fixed dropbox update
---
modules/sd_hijack.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 192883b2..491312b4 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -9,7 +9,7 @@ from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
-from modules.shared import opts, device, cmd_opts, aesthetic_embeddings
+from modules.shared import opts, device, cmd_opts
from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
@@ -182,7 +182,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
image_embs_name = None
if image_embs_name is not None and self.image_embs_name != image_embs_name:
self.image_embs_name = image_embs_name
- self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device)
+ self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
self.image_embs.requires_grad_(False)
--
cgit v1.2.3
From 763b893f319cee280b86e63025eb55e7c16b02e7 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Sun, 16 Oct 2022 10:03:09 +0800
Subject: images history sorting files by date
---
modules/images_history.py | 261 ++++++++++++++++++++++++++++++++++------------
1 file changed, 196 insertions(+), 65 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index f5ef44fe..533cf51b 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -1,33 +1,74 @@
import os
import shutil
+import time
+import hashlib
+import gradio
+show_max_dates_num = 3
+system_bak_path = "webui_log_and_bak"
+def is_valid_date(date):
+ try:
+ time.strptime(date, "%Y%m%d")
+ return True
+ except:
+ return False
+def reduplicative_file_move(src, dst):
+ def same_name_file(basename, path):
+ name, ext = os.path.splitext(basename)
+ f_list = os.listdir(path)
+ max_num = 0
+ for f in f_list:
+ if len(f) <= len(basename):
+ continue
+ f_ext = f[-len(ext):] if len(ext) > 0 else ""
+ if f[:len(name)] == name and f_ext == ext:
+ if f[len(name)] == "(" and f[-len(ext)-1] == ")":
+ number = f[len(name)+1:-len(ext)-1]
+ if number.isdigit():
+ if int(number) > max_num:
+ max_num = int(number)
+ return f"{name}({max_num + 1}){ext}"
+ name = os.path.basename(src)
+ save_name = os.path.join(dst, name)
+ if not os.path.exists(save_name):
+ shutil.move(src, dst)
+ else:
+ name = same_name_file(name, dst)
+ shutil.move(src, os.path.join(dst, name))
-def traverse_all_files(output_dir, image_list, curr_dir=None):
- curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
+def traverse_all_files(curr_path, image_list, all_type=False):
try:
f_list = os.listdir(curr_path)
except:
- if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
- image_list.append(curr_dir)
+ if all_type or curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt":
+ image_list.append(curr_path)
return image_list
for file in f_list:
- file = file if curr_dir is None else os.path.join(curr_dir, file)
- file_path = os.path.join(curr_path, file)
- if file[-4:] == ".txt":
+ file = os.path.join(curr_path, file)
+ if (not all_type) and file[-4:] == ".txt":
pass
- elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
+ elif os.path.isfile(file) and file[-10:].rfind(".") > 0:
image_list.append(file)
else:
- image_list = traverse_all_files(output_dir, image_list, file)
+ image_list = traverse_all_files(file, image_list)
return image_list
-
-def get_recent_images(dir_name, page_index, step, image_index, tabname):
- page_index = int(page_index)
- f_list = os.listdir(dir_name)
+def get_recent_images(dir_name, page_index, step, image_index, tabname, date_from, date_to):
+ #print(f"turn_page {page_index}",date_from)
+ if date_from is None or date_from == "":
+ return None, 1, None, ""
image_list = []
- image_list = traverse_all_files(dir_name, image_list)
- image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
+ date_list = auto_sorting(dir_name)
+ page_index = int(page_index)
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ for date in date_list:
+ if date >= date_from and date <= date_to:
+ path = os.path.join(dir_name, date)
+ if date == today and not os.path.exists(path):
+ continue
+ image_list = traverse_all_files(path, image_list)
+
+ image_list = sorted(image_list, key=lambda file: -os.path.getctime(file))
num = 48 if tabname != "extras" else 12
max_page_index = len(image_list) // num + 1
page_index = max_page_index if page_index == -1 else page_index + step
@@ -38,40 +79,101 @@ def get_recent_images(dir_name, page_index, step, image_index, tabname):
image_index = int(image_index)
if image_index < 0 or image_index > len(image_list) - 1:
current_file = None
- hidden = None
else:
- current_file = image_list[int(image_index)]
- hidden = os.path.join(dir_name, current_file)
- return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
+ current_file = image_list[image_index]
+ return image_list, page_index, image_list, ""
+def auto_sorting(dir_name):
+ #print(f"auto sorting")
+ bak_path = os.path.join(dir_name, system_bak_path)
+ if not os.path.exists(bak_path):
+ os.mkdir(bak_path)
+ log_file = None
+ files_list = []
+ f_list = os.listdir(dir_name)
+ for file in f_list:
+ if file == system_bak_path:
+ continue
+ file_path = os.path.join(dir_name, file)
+ if not is_valid_date(file):
+ if file[-10:].rfind(".") > 0:
+ files_list.append(file_path)
+ else:
+ files_list = traverse_all_files(file_path, files_list, all_type=True)
+
+ for file in files_list:
+ date_str = time.strftime("%Y%m%d",time.localtime(os.path.getctime(file)))
+ file_path = os.path.dirname(file)
+ hash_path = hashlib.md5(file_path.encode()).hexdigest()
+ path = os.path.join(dir_name, date_str, hash_path)
+ if not os.path.exists(path):
+ os.makedirs(path)
+ if log_file is None:
+ log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a")
+ log_file.write(f"{hash_path},{file_path}\n")
+ reduplicative_file_move(file, path)
+
+ date_list = []
+ f_list = os.listdir(dir_name)
+ for f in f_list:
+ if is_valid_date(f):
+ date_list.append(f)
+ elif f == system_bak_path:
+ continue
+ else:
+ reduplicative_file_move(os.path.join(dir_name, f), bak_path)
+
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ if today not in date_list:
+ date_list.append(today)
+ return sorted(date_list)
-def first_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, 1, 0, image_index, tabname)
-def end_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, -1, 0, image_index, tabname)
+def archive_images(dir_name):
+ date_list = auto_sorting(dir_name)
+ date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0]
+ return (
+ gradio.update(visible=False),
+ gradio.update(visible=True),
+ gradio.Dropdown.update(choices=date_list, value=date_list[-1]),
+ gradio.Dropdown.update(choices=date_list, value=date_from)
+ )
+def date_to_change(dir_name, page_index, image_index, tabname, date_from, date_to):
+ #print("date_to", date_to)
+ date_list = auto_sorting(dir_name)
+ date_from_list = [date for date in date_list if date <= date_to]
+ date_from = date_from_list[0] if len(date_from_list) < show_max_dates_num else date_from_list[-show_max_dates_num]
+ image_list, page_index, image_list, _ =get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to)
+ return image_list, page_index, image_list, _, gradio.Dropdown.update(choices=date_from_list, value=date_from)
-def prev_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, page_index, -1, image_index, tabname)
+def first_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
+ return get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to)
-def next_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, page_index, 1, image_index, tabname)
+def end_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
+ return get_recent_images(dir_name, -1, 0, image_index, tabname, date_from, date_to)
-def page_index_change(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, page_index, 0, image_index, tabname)
+def prev_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
+ return get_recent_images(dir_name, page_index, -1, image_index, tabname, date_from, date_to)
-def show_image_info(num, image_path, filenames):
- # print(f"select image {num}")
- file = filenames[int(num)]
- return file, num, os.path.join(image_path, file)
+def next_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
+ return get_recent_images(dir_name, page_index, 1, image_index, tabname, date_from, date_to)
+
+
+def page_index_change(dir_name, page_index, image_index, tabname, date_from, date_to):
+ return get_recent_images(dir_name, page_index, 0, image_index, tabname, date_from, date_to)
-def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
+def show_image_info(tabname_box, num, filenames):
+ # #print(f"select image {num}")
+ file = filenames[int(num)]
+ return file, num, file
+
+def delete_image(delete_num, tabname, name, page_index, filenames, image_index):
if name == "":
return filenames, delete_num
else:
@@ -81,21 +183,19 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima
new_file_list = []
for name in filenames:
if i >= index and i < index + delete_num:
- path = os.path.join(dir_name, name)
- if os.path.exists(path):
- print(f"Delete file {path}")
- os.remove(path)
- txt_file = os.path.splitext(path)[0] + ".txt"
+ if os.path.exists(name):
+ #print(f"Delete file {name}")
+ os.remove(name)
+ txt_file = os.path.splitext(name)[0] + ".txt"
if os.path.exists(txt_file):
os.remove(txt_file)
else:
- print(f"Not exists file {path}")
+ #print(f"Not exists file {name}")
else:
new_file_list.append(name)
i += 1
return new_file_list, 1
-
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
if tabname == "txt2img":
dir_name = opts.outdir_txt2img_samples
@@ -107,16 +207,32 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
dir_name = d[0]
for p in d[1:]:
dir_name = os.path.join(dir_name, p)
- with gr.Row():
- renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
- first_page = gr.Button('First Page')
- prev_page = gr.Button('Prev Page')
- page_index = gr.Number(value=1, label="Page Index")
- next_page = gr.Button('Next Page')
- end_page = gr.Button('End Page')
- with gr.Row(elem_id=tabname + "_images_history"):
+
+ f_list = os.listdir(dir_name)
+ sorted_flag = os.path.exists(os.path.join(dir_name, system_bak_path)) or len(f_list) == 0
+ date_list, date_from, date_to = None, None, None
+ if sorted_flag:
+ #print(sorted_flag)
+ date_list = auto_sorting(dir_name)
+ date_to = date_list[-1]
+ date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0]
+
+ with gr.Column(visible=sorted_flag) as page_panel:
with gr.Row():
+ renew_page = gr.Button('Refresh', elem_id=tabname + "_images_history_renew_page", interactive=sorted_flag)
+ first_page = gr.Button('First Page')
+ prev_page = gr.Button('Prev Page')
+ page_index = gr.Number(value=1, label="Page Index")
+ next_page = gr.Button('Next Page')
+ end_page = gr.Button('End Page')
+
+ with gr.Row(elem_id=tabname + "_images_history"):
with gr.Column(scale=2):
+ with gr.Row():
+ newest = gr.Button('Newest')
+ date_to = gr.Dropdown(choices=date_list, value=date_to, label="Date to")
+ date_from = gr.Dropdown(choices=date_list, value=date_from, label="Date from")
+
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
with gr.Row():
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
@@ -128,22 +244,31 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
with gr.Row():
with gr.Column():
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
- img_file_name = gr.Textbox(label="File Name", interactive=False)
- with gr.Row():
+ img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
# hiden items
+ with gr.Row(visible=False):
+ img_path = gr.Textbox(dir_name)
+ tabname_box = gr.Textbox(tabname)
+ image_index = gr.Textbox(value=-1)
+ set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
+ filenames = gr.State()
+ hidden = gr.Image(type="pil")
+ info1 = gr.Textbox()
+ info2 = gr.Textbox()
+ with gr.Column(visible=not sorted_flag) as init_warning:
+ with gr.Row():
+ gr.Textbox("The system needs to archive the files according to the date. This requires changing the directory structure of the files",
+ label="Waring",
+ css="")
+ with gr.Row():
+ sorted_button = gr.Button('Confirme')
- img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
- tabname_box = gr.Textbox(tabname, visible=False)
- image_index = gr.Textbox(value=-1, visible=False)
- set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
- filenames = gr.State()
- hidden = gr.Image(type="pil", visible=False)
- info1 = gr.Textbox(visible=False)
- info2 = gr.Textbox(visible=False)
-
+
+
+
# turn pages
- gallery_inputs = [img_path, page_index, image_index, tabname_box]
- gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
+ gallery_inputs = [img_path, page_index, image_index, tabname_box, date_from, date_to]
+ gallery_outputs = [history_gallery, page_index, filenames, img_file_name]
first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
@@ -154,15 +279,21 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
# other funcitons
- set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
+ set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, filenames], outputs=[img_file_name, image_index, hidden])
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
- delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
+ delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
-
+ date_to.change(date_to_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs + [date_from])
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
+ sorted_button.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from])
+ newest.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from])
+
+
+
+
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
with gr.Blocks(analytics_enabled=False) as images_history:
--
cgit v1.2.3
From 523140d7805c644700009b8a2483ff4eb4a22304 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sun, 16 Oct 2022 10:23:30 +0200
Subject: ui fix
---
modules/aesthetic_clip.py | 3 +--
modules/sd_hijack.py | 3 +--
modules/shared.py | 2 ++
modules/ui.py | 24 ++++++++++++++----------
4 files changed, 18 insertions(+), 14 deletions(-)
(limited to 'modules')
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
index 68264284..ccb35c73 100644
--- a/modules/aesthetic_clip.py
+++ b/modules/aesthetic_clip.py
@@ -74,5 +74,4 @@ def generate_imgs_embd(name, folder, batch_size):
"""
shared.update_aesthetic_embeddings()
return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
- value=sorted(shared.aesthetic_embeddings.keys())[0] if len(
- shared.aesthetic_embeddings) > 0 else None), res, ""
+ value="None"), res, ""
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 01fcb78f..2de2eed5 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -392,8 +392,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
- if len(text[
- 0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None:
+ if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None:
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [
[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
diff --git a/modules/shared.py b/modules/shared.py
index 3c5ffef1..e2c98b2d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -96,11 +96,13 @@ loaded_hypernetwork = None
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
+aesthetic_embeddings = aesthetic_embeddings | {"None": None}
def update_aesthetic_embeddings():
global aesthetic_embeddings
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
+ aesthetic_embeddings = aesthetic_embeddings | {"None": None}
def reload_hypernetworks():
global hypernetworks
diff --git a/modules/ui.py b/modules/ui.py
index 13ba3142..4069f0d2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -594,19 +594,23 @@ def create_ui(wrap_gradio_gpu_call):
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Group():
- aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001")
-
- aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
- aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
+ with gr.Accordion("Open for Clip Aesthetic!",open=False):
+ with gr.Row():
+ aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
+ aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
- with gr.Row():
- aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
- aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
- aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
+ with gr.Row():
+ aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001")
+ aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
+ aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()),
+ label="Aesthetic imgs embedding",
+ value="None")
- aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Aesthetic imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
+ with gr.Row():
+ aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
+ aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
+ aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
- aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
--
cgit v1.2.3
From e4f8b5f00dd33b7547cc6b76fbed26bb83b37a64 Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sun, 16 Oct 2022 10:28:21 +0200
Subject: ui fix
---
modules/sd_hijack.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 2de2eed5..5d0590af 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -178,7 +178,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.load_image_embs(image_embs_name)
def load_image_embs(self, image_embs_name):
- if image_embs_name is None or len(image_embs_name) == 0:
+ if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
image_embs_name = None
if image_embs_name is not None and self.image_embs_name != image_embs_name:
self.image_embs_name = image_embs_name
--
cgit v1.2.3
From f62905fdf928b54aa76765e5cbde8d538d494e49 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Sun, 16 Oct 2022 21:22:38 +0800
Subject: images history speed up
---
modules/images_history.py | 250 ++++++++++++++++++++++++----------------------
1 file changed, 128 insertions(+), 122 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index 7fd75005..ae0b4e40 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -3,8 +3,10 @@ import shutil
import time
import hashlib
import gradio
-show_max_dates_num = 3
+
system_bak_path = "webui_log_and_bak"
+loads_files_num = 216
+num_of_imgs_per_page = 36
def is_valid_date(date):
try:
time.strptime(date, "%Y%m%d")
@@ -53,38 +55,7 @@ def traverse_all_files(curr_path, image_list, all_type=False):
image_list = traverse_all_files(file, image_list)
return image_list
-def get_recent_images(dir_name, page_index, step, image_index, tabname, date_from, date_to):
- #print(f"turn_page {page_index}",date_from)
- if date_from is None or date_from == "":
- return None, 1, None, ""
- image_list = []
- date_list = auto_sorting(dir_name)
- page_index = int(page_index)
- today = time.strftime("%Y%m%d",time.localtime(time.time()))
- for date in date_list:
- if date >= date_from and date <= date_to:
- path = os.path.join(dir_name, date)
- if date == today and not os.path.exists(path):
- continue
- image_list = traverse_all_files(path, image_list)
-
- image_list = sorted(image_list, key=lambda file: -os.path.getctime(file))
- num = 48 if tabname != "extras" else 12
- max_page_index = len(image_list) // num + 1
- page_index = max_page_index if page_index == -1 else page_index + step
- page_index = 1 if page_index < 1 else page_index
- page_index = max_page_index if page_index > max_page_index else page_index
- idx_frm = (page_index - 1) * num
- image_list = image_list[idx_frm:idx_frm + num]
- image_index = int(image_index)
- if image_index < 0 or image_index > len(image_list) - 1:
- current_file = None
- else:
- current_file = image_list[image_index]
- return image_list, page_index, image_list, ""
-
-def auto_sorting(dir_name):
- #print(f"auto sorting")
+def auto_sorting(dir_name):
bak_path = os.path.join(dir_name, system_bak_path)
if not os.path.exists(bak_path):
os.mkdir(bak_path)
@@ -126,102 +97,131 @@ def auto_sorting(dir_name):
today = time.strftime("%Y%m%d",time.localtime(time.time()))
if today not in date_list:
date_list.append(today)
- return sorted(date_list)
+ return sorted(date_list, reverse=True)
-def archive_images(dir_name):
+def archive_images(dir_name, date_to):
date_list = auto_sorting(dir_name)
- date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0]
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ date_to = today if date_to is None or date_to == "" else date_to
+ filenames = []
+ for date in date_list:
+ if date <= date_to:
+ path = os.path.join(dir_name, date)
+ if date == today and not os.path.exists(path):
+ continue
+ filenames = traverse_all_files(path, filenames)
+ if len(filenames) > loads_files_num:
+ break
+ filenames = sorted(filenames, key=lambda file: -os.path.getctime(file))
+ _, image_list, _, visible_num = get_recent_images(1, 0, filenames)
return (
gradio.update(visible=False),
gradio.update(visible=True),
- gradio.Dropdown.update(choices=date_list, value=date_list[-1]),
- gradio.Dropdown.update(choices=date_list, value=date_from)
+ gradio.Dropdown.update(choices=date_list, value=date_to),
+ date,
+ filenames,
+ 1,
+ image_list,
+ "",
+ visible_num
)
+def system_init(dir_name):
+ ret = [x for x in archive_images(dir_name, None)]
+ ret += [gradio.update(visible=False)]
+ return ret
+
+def newest_click(dir_name, date_to):
+ if date_to == "start":
+ return True, False, "start", None, None, 1, None, ""
+ else:
+ return archive_images(dir_name, time.strftime("%Y%m%d",time.localtime(time.time())))
-def date_to_change(dir_name, page_index, image_index, tabname, date_from, date_to):
- #print("date_to", date_to)
- date_list = auto_sorting(dir_name)
- date_from_list = [date for date in date_list if date <= date_to]
- date_from = date_from_list[0] if len(date_from_list) < show_max_dates_num else date_from_list[-show_max_dates_num]
- image_list, page_index, image_list, _ =get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to)
- return image_list, page_index, image_list, _, gradio.Dropdown.update(choices=date_from_list, value=date_from)
-
-def first_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
- return get_recent_images(dir_name, 1, 0, image_index, tabname, date_from, date_to)
-
-
-def end_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
- return get_recent_images(dir_name, -1, 0, image_index, tabname, date_from, date_to)
-
-
-def prev_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
- return get_recent_images(dir_name, page_index, -1, image_index, tabname, date_from, date_to)
-
-
-def next_page_click(dir_name, page_index, image_index, tabname, date_from, date_to):
- return get_recent_images(dir_name, page_index, 1, image_index, tabname, date_from, date_to)
-
-
-def page_index_change(dir_name, page_index, image_index, tabname, date_from, date_to):
- return get_recent_images(dir_name, page_index, 0, image_index, tabname, date_from, date_to)
-
-
-def show_image_info(tabname_box, num, filenames):
- # #print(f"select image {num}")
- file = filenames[int(num)]
- return file, num, file
-
-def delete_image(delete_num, tabname, name, page_index, filenames, image_index):
+def delete_image(delete_num, name, filenames, image_index, visible_num):
if name == "":
return filenames, delete_num
else:
delete_num = int(delete_num)
+ visible_num = int(visible_num)
+ image_index = int(image_index)
index = list(filenames).index(name)
i = 0
new_file_list = []
for name in filenames:
if i >= index and i < index + delete_num:
if os.path.exists(name):
- #print(f"Delete file {name}")
+ if visible_num == image_index:
+ new_file_list.append(name)
+ continue
+ print(f"Delete file {name}")
os.remove(name)
+ visible_num -= 1
txt_file = os.path.splitext(name)[0] + ".txt"
if os.path.exists(txt_file):
os.remove(txt_file)
else:
- #print(f"Not exists file {name}")
+ print(f"Not exists file {name}")
else:
new_file_list.append(name)
i += 1
- return new_file_list, 1
+ return new_file_list, 1, visible_num
+
+def get_recent_images(page_index, step, filenames):
+ page_index = int(page_index)
+ max_page_index = len(filenames) // num_of_imgs_per_page + 1
+ page_index = max_page_index if page_index == -1 else page_index + step
+ page_index = 1 if page_index < 1 else page_index
+ page_index = max_page_index if page_index > max_page_index else page_index
+ idx_frm = (page_index - 1) * num_of_imgs_per_page
+ image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page]
+ length = len(filenames)
+ visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page
+ visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num
+ return page_index, image_list, "", visible_num
+
+def first_page_click(page_index, filenames):
+ return get_recent_images(1, 0, filenames)
+
+def end_page_click(page_index, filenames):
+ return get_recent_images(-1, 0, filenames)
+
+def prev_page_click(page_index, filenames):
+ return get_recent_images(page_index, -1, filenames)
+
+def next_page_click(page_index, filenames):
+ return get_recent_images(page_index, 1, filenames)
+
+def page_index_change(page_index, filenames):
+ return get_recent_images(page_index, 0, filenames)
+
+def show_image_info(tabname_box, num, page_index, filenames):
+ file = filenames[int(num) + int((page_index - 1) * num_of_imgs_per_page)]
+ return file, num, file
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
- if opts.outdir_samples != "":
- dir_name = opts.outdir_samples
- elif tabname == "txt2img":
+ if tabname == "txt2img":
dir_name = opts.outdir_txt2img_samples
elif tabname == "img2img":
dir_name = opts.outdir_img2img_samples
elif tabname == "extras":
dir_name = opts.outdir_extras_samples
+ elif tabname == "saved":
+ dir_name = opts.outdir_save
+ if not os.path.exists(dir_name):
+ os.makedirs(dir_name)
d = dir_name.split("/")
- dir_name = "/" if dir_name.startswith("/") else d[0]
+ dir_name = d[0]
for p in d[1:]:
dir_name = os.path.join(dir_name, p)
f_list = os.listdir(dir_name)
sorted_flag = os.path.exists(os.path.join(dir_name, system_bak_path)) or len(f_list) == 0
date_list, date_from, date_to = None, None, None
- if sorted_flag:
- #print(sorted_flag)
- date_list = auto_sorting(dir_name)
- date_to = date_list[-1]
- date_from = date_list[-show_max_dates_num] if len(date_list) > show_max_dates_num else date_list[0]
with gr.Column(visible=sorted_flag) as page_panel:
with gr.Row():
- renew_page = gr.Button('Refresh', elem_id=tabname + "_images_history_renew_page", interactive=sorted_flag)
+ #renew_page = gr.Button('Refresh')
first_page = gr.Button('First Page')
prev_page = gr.Button('Prev Page')
page_index = gr.Number(value=1, label="Page Index")
@@ -231,9 +231,9 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
with gr.Row(elem_id=tabname + "_images_history"):
with gr.Column(scale=2):
with gr.Row():
- newest = gr.Button('Newest')
- date_to = gr.Dropdown(choices=date_list, value=date_to, label="Date to")
- date_from = gr.Dropdown(choices=date_list, value=date_from, label="Date from")
+ newest = gr.Button('Refresh', elem_id=tabname + "_images_history_start")
+ date_from = gr.Textbox(label="Date from", interactive=False)
+ date_to = gr.Dropdown(value="start" if not sorted_flag else None, label="Date to")
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
with gr.Row():
@@ -247,66 +247,72 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
with gr.Column():
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
+
# hiden items
- with gr.Row(visible=False):
+ with gr.Row(visible=False):
+ visible_img_num = gr.Number()
img_path = gr.Textbox(dir_name)
tabname_box = gr.Textbox(tabname)
image_index = gr.Textbox(value=-1)
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
filenames = gr.State()
+ all_images_list = gr.State()
hidden = gr.Image(type="pil")
info1 = gr.Textbox()
info2 = gr.Textbox()
+
with gr.Column(visible=not sorted_flag) as init_warning:
with gr.Row():
- gr.Textbox("The system needs to archive the files according to the date. This requires changing the directory structure of the files",
- label="Waring",
- css="")
+ warning = gr.Textbox(
+ label="Waring",
+ value=f"The system needs to archive the files according to the date. This requires changing the directory structure of the files.If you have doubts about this operation, you can first back up the files in the '{dir_name}' directory"
+ )
+ warning.style(height=100, width=50)
with gr.Row():
sorted_button = gr.Button('Confirme')
-
-
+ change_date_output = [init_warning, page_panel, date_to, date_from, filenames, page_index, history_gallery, img_file_name, visible_img_num]
+ sorted_button.click(system_init, inputs=[img_path], outputs=change_date_output + [sorted_button])
+ newest.click(newest_click, inputs=[img_path, date_to], outputs=change_date_output)
+ date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output)
+ date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ newest.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+
+ delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num])
+ delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None)
+
# turn pages
- gallery_inputs = [img_path, page_index, image_index, tabname_box, date_from, date_to]
- gallery_outputs = [history_gallery, page_index, filenames, img_file_name]
+ gallery_inputs = [page_index, filenames]
+ gallery_outputs = [page_index, history_gallery, img_file_name, visible_img_num]
+
+ first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
+ next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
+ prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
+ end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
+ page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
- first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
+ first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
# other funcitons
- set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, filenames], outputs=[img_file_name, image_index, hidden])
- img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
- delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
+ set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden])
+ img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
- date_to.change(date_to_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs + [date_from])
- # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
+
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
- sorted_button.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from])
- newest.click(archive_images, inputs=[img_path], outputs=[init_warning, page_panel, date_to, date_from])
-
-
-
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
with gr.Blocks(analytics_enabled=False) as images_history:
with gr.Tabs() as tabs:
- with gr.Tab("txt2img history"):
- with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
- show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
- with gr.Tab("img2img history"):
- with gr.Blocks(analytics_enabled=False) as images_history_img2img:
- show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
- with gr.Tab("extras history"):
- with gr.Blocks(analytics_enabled=False) as images_history_img2img:
- show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
+ for tab in ["saved", "txt2img", "img2img", "extras"]:
+ with gr.Tab(tab):
+ with gr.Blocks(analytics_enabled=False) as images_history_img2img:
+ show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
return images_history
--
cgit v1.2.3
From a4de699e3c235d83b5a957d08779cb41cb0781bc Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Sun, 16 Oct 2022 22:37:12 +0800
Subject: Images history speed up
---
modules/images_history.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index ae0b4e40..94bd16a8 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -153,6 +153,7 @@ def delete_image(delete_num, name, filenames, image_index, visible_num):
if os.path.exists(name):
if visible_num == image_index:
new_file_list.append(name)
+ i += 1
continue
print(f"Delete file {name}")
os.remove(name)
@@ -221,7 +222,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
with gr.Column(visible=sorted_flag) as page_panel:
with gr.Row():
- #renew_page = gr.Button('Refresh')
+ renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
first_page = gr.Button('First Page')
prev_page = gr.Button('Prev Page')
page_index = gr.Number(value=1, label="Page Index")
@@ -231,7 +232,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
with gr.Row(elem_id=tabname + "_images_history"):
with gr.Column(scale=2):
with gr.Row():
- newest = gr.Button('Refresh', elem_id=tabname + "_images_history_start")
+ newest = gr.Button('Reload', elem_id=tabname + "_images_history_start")
date_from = gr.Textbox(label="Date from", interactive=False)
date_to = gr.Dropdown(value="start" if not sorted_flag else None, label="Date to")
@@ -291,12 +292,14 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
+ renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
# other funcitons
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden])
--
cgit v1.2.3
From 9324cdaa3199d65c182858785dd1eca42b192b8e Mon Sep 17 00:00:00 2001
From: MalumaDev
Date: Sun, 16 Oct 2022 17:53:56 +0200
Subject: ui fix, re organization of the code
---
modules/aesthetic_clip.py | 154 +++++++++++++++++++++++++++++++++--
modules/img2img.py | 14 +++-
modules/processing.py | 29 ++-----
modules/sd_hijack.py | 102 ++---------------------
modules/sd_models.py | 5 +-
modules/shared.py | 14 +++-
modules/textual_inversion/dataset.py | 2 +-
modules/txt2img.py | 18 ++--
modules/ui.py | 52 +++++++-----
9 files changed, 233 insertions(+), 157 deletions(-)
(limited to 'modules')
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
index ccb35c73..34efa931 100644
--- a/modules/aesthetic_clip.py
+++ b/modules/aesthetic_clip.py
@@ -1,3 +1,4 @@
+import copy
import itertools
import os
from pathlib import Path
@@ -7,11 +8,12 @@ import gc
import gradio as gr
import torch
from PIL import Image
-from modules import shared
-from modules.shared import device
-from transformers import CLIPModel, CLIPProcessor
+from torch import optim
-from tqdm.auto import tqdm
+from modules import shared
+from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
+from tqdm.auto import tqdm, trange
+from modules.shared import opts, device
def get_all_images_in_folder(folder):
@@ -37,12 +39,39 @@ def iter_to_batched(iterable, n=1):
yield chunk
+def create_ui():
+ with gr.Group():
+ with gr.Accordion("Open for Clip Aesthetic!", open=False):
+ with gr.Row():
+ aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight",
+ value=0.9)
+ aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
+
+ with gr.Row():
+ aesthetic_lr = gr.Textbox(label='Aesthetic learning rate',
+ placeholder="Aesthetic learning rate", value="0.0001")
+ aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
+ aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
+ label="Aesthetic imgs embedding",
+ value="None")
+
+ with gr.Row():
+ aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
+ placeholder="This text is used to rotate the feature space of the imgs embs",
+ value="")
+ aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01,
+ value=0.1)
+ aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
+
+ return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
+
+
def generate_imgs_embd(name, folder, batch_size):
# clipModel = CLIPModel.from_pretrained(
# shared.sd_model.cond_stage_model.clipModel.name_or_path
# )
- model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device)
- processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path)
+ model = shared.clip_model.to(device)
+ processor = CLIPProcessor.from_pretrained(model.name_or_path)
with torch.no_grad():
embs = []
@@ -63,7 +92,6 @@ def generate_imgs_embd(name, folder, batch_size):
torch.save(embs, path)
model = model.cpu()
- del model
del processor
del embs
gc.collect()
@@ -74,4 +102,114 @@ def generate_imgs_embd(name, folder, batch_size):
"""
shared.update_aesthetic_embeddings()
return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
- value="None"), res, ""
+ value="None"), \
+ gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
+ label="Imgs embedding",
+ value="None"), res, ""
+
+
+def slerp(low, high, val):
+ low_norm = low / torch.norm(low, dim=1, keepdim=True)
+ high_norm = high / torch.norm(high, dim=1, keepdim=True)
+ omega = torch.acos((low_norm * high_norm).sum(1))
+ so = torch.sin(omega)
+ res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
+ return res
+
+
+class AestheticCLIP:
+ def __init__(self):
+ self.skip = False
+ self.aesthetic_steps = 0
+ self.aesthetic_weight = 0
+ self.aesthetic_lr = 0
+ self.slerp = False
+ self.aesthetic_text_negative = ""
+ self.aesthetic_slerp_angle = 0
+ self.aesthetic_imgs_text = ""
+
+ self.image_embs_name = None
+ self.image_embs = None
+ self.load_image_embs(None)
+
+ def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
+ aesthetic_slerp=True, aesthetic_imgs_text="",
+ aesthetic_slerp_angle=0.15,
+ aesthetic_text_negative=False):
+ self.aesthetic_imgs_text = aesthetic_imgs_text
+ self.aesthetic_slerp_angle = aesthetic_slerp_angle
+ self.aesthetic_text_negative = aesthetic_text_negative
+ self.slerp = aesthetic_slerp
+ self.aesthetic_lr = aesthetic_lr
+ self.aesthetic_weight = aesthetic_weight
+ self.aesthetic_steps = aesthetic_steps
+ self.load_image_embs(image_embs_name)
+
+ def set_skip(self, skip):
+ self.skip = skip
+
+ def load_image_embs(self, image_embs_name):
+ if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
+ image_embs_name = None
+ self.image_embs_name = None
+ if image_embs_name is not None and self.image_embs_name != image_embs_name:
+ self.image_embs_name = image_embs_name
+ self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
+ self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
+ self.image_embs.requires_grad_(False)
+
+ def __call__(self, z, remade_batch_tokens):
+ if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
+ tokenizer = shared.sd_model.cond_stage_model.tokenizer
+ if not opts.use_old_emphasis_implementation:
+ remade_batch_tokens = [
+ [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
+ remade_batch_tokens]
+
+ tokens = torch.asarray(remade_batch_tokens).to(device)
+
+ model = copy.deepcopy(shared.clip_model).to(device)
+ model.requires_grad_(True)
+ if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
+ text_embs_2 = model.get_text_features(
+ **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
+ if self.aesthetic_text_negative:
+ text_embs_2 = self.image_embs - text_embs_2
+ text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
+ img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
+ else:
+ img_embs = self.image_embs
+
+ with torch.enable_grad():
+
+ # We optimize the model to maximize the similarity
+ optimizer = optim.Adam(
+ model.text_model.parameters(), lr=self.aesthetic_lr
+ )
+
+ for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
+ text_embs = model.get_text_features(input_ids=tokens)
+ text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
+ sim = text_embs @ img_embs.T
+ loss = -sim
+ optimizer.zero_grad()
+ loss.mean().backward()
+ optimizer.step()
+
+ zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+ if opts.CLIP_stop_at_last_layers > 1:
+ zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
+ zn = model.text_model.final_layer_norm(zn)
+ else:
+ zn = zn.last_hidden_state
+ model.cpu()
+ del model
+ gc.collect()
+ torch.cuda.empty_cache()
+ zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
+ if self.slerp:
+ z = slerp(z, zn, self.aesthetic_weight)
+ else:
+ z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
+
+ return z
diff --git a/modules/img2img.py b/modules/img2img.py
index 24126774..4ed80c4b 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -56,7 +56,14 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str,
+ aesthetic_lr=0,
+ aesthetic_weight=0, aesthetic_steps=0,
+ aesthetic_imgs=None,
+ aesthetic_slerp=False,
+ aesthetic_imgs_text="",
+ aesthetic_slerp_angle=0.15,
+ aesthetic_text_negative=False, *args):
is_inpaint = mode == 1
is_batch = mode == 2
@@ -109,6 +116,11 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert,
)
+ shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
+ aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text,
+ aesthetic_slerp_angle,
+ aesthetic_text_negative)
+
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/processing.py b/modules/processing.py
index 1db26c3e..685f9fcd 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -146,7 +146,8 @@ class Processed:
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
- self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
+ self.subseed = int(
+ self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed]
@@ -332,16 +333,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
-def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0,
- aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="",
- aesthetic_slerp_angle=0.15,
- aesthetic_text_negative=False) -> Processed:
+def process_images(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
- aesthetic_lr = float(aesthetic_lr)
- aesthetic_weight = float(aesthetic_weight)
- aesthetic_steps = int(aesthetic_steps)
-
if type(p.prompt) == list:
assert (len(p.prompt) > 0)
else:
@@ -417,16 +411,10 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
# uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
# c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
- if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
- shared.sd_model.cond_stage_model.set_aesthetic_params()
+ shared.aesthetic_clip.set_skip(True)
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt],
p.steps)
- if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
- shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight,
- aesthetic_steps, aesthetic_imgs,
- aesthetic_slerp, aesthetic_imgs_text,
- aesthetic_slerp_angle,
- aesthetic_text_negative)
+ shared.aesthetic_clip.set_skip(False)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0:
@@ -582,7 +570,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
-
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
@@ -600,10 +587,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
- samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
+ samples = samples[:, :, self.truncate_y // 2:samples.shape[2] - self.truncate_y // 2,
+ self.truncate_x // 2:samples.shape[3] - self.truncate_x // 2]
if opts.use_scale_latent_for_hires_fix:
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f),
+ mode="bilinear")
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 5d0590af..227e7670 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -29,8 +29,8 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu
-
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
+ if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (
+ 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
@@ -118,33 +118,14 @@ class StableDiffusionModelHijack:
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
-def slerp(low, high, val):
- low_norm = low / torch.norm(low, dim=1, keepdim=True)
- high_norm = high / torch.norm(high, dim=1, keepdim=True)
- omega = torch.acos((low_norm * high_norm).sum(1))
- so = torch.sin(omega)
- res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
- return res
-
-
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
- self.clipModel = CLIPModel.from_pretrained(
- self.wrapped.transformer.name_or_path
- )
- del self.clipModel.vision_model
- self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path)
- self.hijack: StableDiffusionModelHijack = hijack
- self.tokenizer = wrapped.tokenizer
- # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval()
- self.image_embs_name = None
- self.image_embs = None
- self.load_image_embs(None)
self.token_mults = {}
-
+ self.hijack: StableDiffusionModelHijack = hijack
+ self.tokenizer = wrapped.tokenizer
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0]
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if
@@ -164,28 +145,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
- def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
- aesthetic_slerp=True, aesthetic_imgs_text="",
- aesthetic_slerp_angle=0.15,
- aesthetic_text_negative=False):
- self.aesthetic_imgs_text = aesthetic_imgs_text
- self.aesthetic_slerp_angle = aesthetic_slerp_angle
- self.aesthetic_text_negative = aesthetic_text_negative
- self.slerp = aesthetic_slerp
- self.aesthetic_lr = aesthetic_lr
- self.aesthetic_weight = aesthetic_weight
- self.aesthetic_steps = aesthetic_steps
- self.load_image_embs(image_embs_name)
-
- def load_image_embs(self, image_embs_name):
- if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
- image_embs_name = None
- if image_embs_name is not None and self.image_embs_name != image_embs_name:
- self.image_embs_name = image_embs_name
- self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
- self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
- self.image_embs.requires_grad_(False)
-
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_end = self.wrapped.tokenizer.eos_token_id
@@ -391,58 +350,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
- if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None:
- if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [
- [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
- remade_batch_tokens]
-
- tokens = torch.asarray(remade_batch_tokens).to(device)
-
- model = copy.deepcopy(self.clipModel).to(device)
- model.requires_grad_(True)
- if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
- text_embs_2 = model.get_text_features(
- **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
- if self.aesthetic_text_negative:
- text_embs_2 = self.image_embs - text_embs_2
- text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
- img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
- else:
- img_embs = self.image_embs
-
- with torch.enable_grad():
-
- # We optimize the model to maximize the similarity
- optimizer = optim.Adam(
- model.text_model.parameters(), lr=self.aesthetic_lr
- )
-
- for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
- text_embs = model.get_text_features(input_ids=tokens)
- text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
- sim = text_embs @ img_embs.T
- loss = -sim
- optimizer.zero_grad()
- loss.mean().backward()
- optimizer.step()
-
- zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
- if opts.CLIP_stop_at_last_layers > 1:
- zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
- zn = model.text_model.final_layer_norm(zn)
- else:
- zn = zn.last_hidden_state
- model.cpu()
- del model
-
- zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1)
- if self.slerp:
- z = slerp(z, zn, self.aesthetic_weight)
- else:
- z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
-
+ z = shared.aesthetic_clip(z, remade_batch_tokens)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3aa21ec1..8e4ee435 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -20,7 +20,7 @@ checkpoints_loaded = collections.OrderedDict()
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
- from transformers import logging
+ from transformers import logging, CLIPModel
logging.set_verbosity_error()
except Exception:
@@ -196,6 +196,9 @@ def load_model():
sd_hijack.model_hijack.hijack(sd_model)
+ if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path:
+ shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path)
+
sd_model.eval()
print(f"Model loaded.")
diff --git a/modules/shared.py b/modules/shared.py
index e2c98b2d..e19ca779 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -3,6 +3,7 @@ import datetime
import json
import os
import sys
+from collections import OrderedDict
import gradio as gr
import tqdm
@@ -94,15 +95,15 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
-aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
- os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
-aesthetic_embeddings = aesthetic_embeddings | {"None": None}
+aesthetic_embeddings = {}
def update_aesthetic_embeddings():
global aesthetic_embeddings
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
- aesthetic_embeddings = aesthetic_embeddings | {"None": None}
+ aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings)
+
+update_aesthetic_embeddings()
def reload_hypernetworks():
global hypernetworks
@@ -381,6 +382,11 @@ sd_upscalers = []
sd_model = None
+clip_model = None
+
+from modules.aesthetic_clip import AestheticCLIP
+aesthetic_clip = AestheticCLIP()
+
progress_print_out = sys.stdout
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 68ceffe3..23bb4b6a 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -49,7 +49,7 @@ class PersonalizedBase(Dataset):
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
try:
- image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC)
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception:
continue
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 8f394d05..6cbc50fc 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,12 +1,17 @@
import modules.scripts
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
+ StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int,aesthetic_lr=0,
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int,
+ restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int,
+ subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool,
+ height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int,
+ firstphase_height: int, aesthetic_lr=0,
aesthetic_weight=0, aesthetic_steps=0,
aesthetic_imgs=None,
aesthetic_slerp=False,
@@ -41,15 +46,17 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
firstphase_height=firstphase_height if enable_hr else None,
)
+ shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
+ aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle,
+ aesthetic_text_negative)
+
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
- processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,aesthetic_imgs_text,
- aesthetic_slerp_angle,
- aesthetic_text_negative)
+ processed = process_images(p)
shared.total_tqdm.clear()
@@ -61,4 +68,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info)
-
diff --git a/modules/ui.py b/modules/ui.py
index 4069f0d2..0e5d73f0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -43,7 +43,7 @@ from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui
-import modules.aesthetic_clip
+import modules.aesthetic_clip as aesthetic_clip
import modules.images_history as img_his
@@ -593,23 +593,25 @@ def create_ui(wrap_gradio_gpu_call):
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
- with gr.Group():
- with gr.Accordion("Open for Clip Aesthetic!",open=False):
- with gr.Row():
- aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
- aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
-
- with gr.Row():
- aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001")
- aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
- aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()),
- label="Aesthetic imgs embedding",
- value="None")
-
- with gr.Row():
- aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
- aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
- aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
+ # with gr.Group():
+ # with gr.Accordion("Open for Clip Aesthetic!",open=False):
+ # with gr.Row():
+ # aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
+ # aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
+ #
+ # with gr.Row():
+ # aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001")
+ # aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
+ # aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()),
+ # label="Aesthetic imgs embedding",
+ # value="None")
+ #
+ # with gr.Row():
+ # aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
+ # aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
+ # aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
+
+ aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui()
with gr.Row():
@@ -840,6 +842,9 @@ def create_ui(wrap_gradio_gpu_call):
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui()
+
+
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
tiling = gr.Checkbox(label='Tiling', value=False)
@@ -944,6 +949,14 @@ def create_ui(wrap_gradio_gpu_call):
inpainting_mask_invert,
img2img_batch_input_dir,
img2img_batch_output_dir,
+ aesthetic_lr_im,
+ aesthetic_weight_im,
+ aesthetic_steps_im,
+ aesthetic_imgs_im,
+ aesthetic_slerp_im,
+ aesthetic_imgs_text_im,
+ aesthetic_slerp_angle_im,
+ aesthetic_text_negative_im,
] + custom_inputs,
outputs=[
img2img_gallery,
@@ -1283,7 +1296,7 @@ def create_ui(wrap_gradio_gpu_call):
)
create_embedding_ae.click(
- fn=modules.aesthetic_clip.generate_imgs_embd,
+ fn=aesthetic_clip.generate_imgs_embd,
inputs=[
new_embedding_name_ae,
process_src_ae,
@@ -1291,6 +1304,7 @@ def create_ui(wrap_gradio_gpu_call):
],
outputs=[
aesthetic_imgs,
+ aesthetic_imgs_im,
ti_output,
ti_outcome,
]
--
cgit v1.2.3
From 9d702b16f01795c3af900e0ebd70faf4b25200f6 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Mon, 17 Oct 2022 16:11:03 +0800
Subject: fix two little bug
---
modules/images_history.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index 23045df1..1ae168ca 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -133,7 +133,7 @@ def archive_images(dir_name, date_to):
date = sort_array[loads_num][2]
filenames = [x[1] for x in sort_array]
else:
- date = sort_array[loads_num][2]
+ date = sort_array[-1][2]
filenames = [x[1] for x in sort_array]
filenames = [x[1] for x in sort_array if x[2]>= date]
_, image_list, _, visible_num = get_recent_images(1, 0, filenames)
@@ -334,6 +334,6 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict):
with gr.Tab(tab):
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
- gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory") #, visible=False)
+ gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False)
return images_history
--
cgit v1.2.3
From 60251c9456f5472784862896c2f97e38feb42482 Mon Sep 17 00:00:00 2001
From: arcticfaded
Date: Mon, 17 Oct 2022 06:58:42 +0000
Subject: initial prototype by borrowing contracts
---
modules/api/api.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++
modules/processing.py | 2 +-
modules/shared.py | 2 +-
3 files changed, 62 insertions(+), 2 deletions(-)
create mode 100644 modules/api/api.py
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
new file mode 100644
index 00000000..9d7c699d
--- /dev/null
+++ b/modules/api/api.py
@@ -0,0 +1,60 @@
+from modules.api.processing import StableDiffusionProcessingAPI
+from modules.processing import StableDiffusionProcessingTxt2Img, process_images
+import modules.shared as shared
+import uvicorn
+from fastapi import FastAPI, Body, APIRouter
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel, Field, Json
+import json
+import io
+import base64
+
+app = FastAPI()
+
+class TextToImageResponse(BaseModel):
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: Json
+ info: Json
+
+
+class Api:
+ def __init__(self, txt2img, img2img, run_extras, run_pnginfo):
+ self.router = APIRouter()
+ app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"])
+
+ def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
+ print(txt2imgreq)
+ p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq))
+ p.sd_model = shared.sd_model
+ print(p)
+ processed = process_images(p)
+
+ b64images = []
+ for i in processed.images:
+ buffer = io.BytesIO()
+ i.save(buffer, format="png")
+ b64images.append(base64.b64encode(buffer.getvalue()))
+
+ response = {
+ "images": b64images,
+ "info": processed.js(),
+ "parameters": json.dumps(vars(txt2imgreq))
+ }
+
+
+ return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
+
+
+
+ def img2imgendoint(self):
+ raise NotImplementedError
+
+ def extrasendoint(self):
+ raise NotImplementedError
+
+ def pnginfoendoint(self):
+ raise NotImplementedError
+
+ def launch(self, server_name, port):
+ app.include_router(self.router)
+ uvicorn.run(app, host=server_name, port=port)
\ No newline at end of file
diff --git a/modules/processing.py b/modules/processing.py
index deb6125e..4a7c6ccc 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -723,4 +723,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
del x
devices.torch_gc()
- return samples
+ return samples
\ No newline at end of file
diff --git a/modules/shared.py b/modules/shared.py
index c2775603..6c6405fd 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -74,7 +74,7 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
-
+parser.add_argument("--api", action='store_true', help="use api=True to launch the api instead of the webui")
cmd_opts = parser.parse_args()
restricted_opts = [
--
cgit v1.2.3
From 9e02812afd10582f00a7fbbfa63c8f9188678e26 Mon Sep 17 00:00:00 2001
From: arcticfaded
Date: Mon, 17 Oct 2022 07:02:08 +0000
Subject: pydantic instrumentation
---
modules/api/processing.py | 99 +++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 99 insertions(+)
create mode 100644 modules/api/processing.py
(limited to 'modules')
diff --git a/modules/api/processing.py b/modules/api/processing.py
new file mode 100644
index 00000000..459a8f49
--- /dev/null
+++ b/modules/api/processing.py
@@ -0,0 +1,99 @@
+from inflection import underscore
+from typing import Any, Dict, Optional
+from pydantic import BaseModel, Field, create_model
+from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+import inspect
+
+
+class ModelDef(BaseModel):
+ """Assistance Class for Pydantic Dynamic Model Generation"""
+
+ field: str
+ field_alias: str
+ field_type: Any
+ field_value: Any
+
+
+class pydanticModelGenerator:
+ """
+ Takes source_data:Dict ( a single instance example of something like a JSON node) and self generates a pythonic data model with Alias to original source field names. This makes it easy to popuate or export to other systems yet handle the data in a pythonic way.
+ Being a pydantic datamodel all the richness of pydantic data validation is available and these models can easily be used in FastAPI and or a ORM
+
+ It does not process full JSON data structures but takes simple JSON document with basic elements
+
+ Provide a model_name, an example of JSON data and a dict of type overrides
+
+ Example:
+
+ source_data = {'Name': '48 Rainbow Rd',
+ 'GroupAddressStyle': 'ThreeLevel',
+ 'LastModified': '2020-12-21T07:02:51.2400232Z',
+ 'ProjectStart': '2020-12-03T07:36:03.324856Z',
+ 'Comment': '',
+ 'CompletionStatus': 'Editing',
+ 'LastUsedPuid': '955',
+ 'Guid': '0c85957b-c2ae-4985-9752-b300ab385b36'}
+
+ source_overrides = {'Guid':{'type':uuid.UUID},
+ 'LastModified':{'type':datetime },
+ 'ProjectStart':{'type':datetime },
+ }
+ source_optionals = {"Comment":True}
+
+ #create Model
+ model_Project=pydanticModelGenerator(
+ model_name="Project",
+ source_data=source_data,
+ overrides=source_overrides,
+ optionals=source_optionals).generate_model()
+
+ #create instance using DynamicModel
+ project_instance=model_Project(**project_info)
+
+ """
+
+ def __init__(
+ self,
+ model_name: str = None,
+ source_data: str = None,
+ params: Dict = {},
+ overrides: Dict = {},
+ optionals: Dict = {},
+ ):
+ def field_type_generator(k, v, overrides, optionals):
+ print(k, v)
+ field_type = str if not overrides.get(k) else overrides[k]["type"]
+ if v is None:
+ field_type = Any
+ else:
+ field_type = type(v)
+
+ return Optional[field_type]
+
+ self._model_name = model_name
+ self._json_data = source_data
+ self._model_def = [
+ ModelDef(
+ field=underscore(k),
+ field_alias=k,
+ field_type=field_type_generator(k, v, overrides, optionals),
+ field_value=v
+ )
+ for (k,v) in source_data.items() if k in params
+ ]
+
+ def generate_model(self):
+ """
+ Creates a pydantic BaseModel
+ from the json and overrides provided at initialization
+ """
+ fields = {
+ d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
+ }
+ DynamicModel = create_model(self._model_name, **fields)
+ DynamicModel.__config__.allow_population_by_field_name = True
+ return DynamicModel
+
+StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing",
+ StableDiffusionProcessing().__dict__,
+ inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model()
\ No newline at end of file
--
cgit v1.2.3
From 832b490e5173f78c4d3aa7ca9ca9ac794d140664 Mon Sep 17 00:00:00 2001
From: Jonathan
Date: Mon, 17 Oct 2022 03:18:41 -0400
Subject: Update processing.py
---
modules/api/processing.py | 41 +++++------------------------------------
1 file changed, 5 insertions(+), 36 deletions(-)
(limited to 'modules')
diff --git a/modules/api/processing.py b/modules/api/processing.py
index 459a8f49..4c3d0bd0 100644
--- a/modules/api/processing.py
+++ b/modules/api/processing.py
@@ -16,46 +16,15 @@ class ModelDef(BaseModel):
class pydanticModelGenerator:
"""
- Takes source_data:Dict ( a single instance example of something like a JSON node) and self generates a pythonic data model with Alias to original source field names. This makes it easy to popuate or export to other systems yet handle the data in a pythonic way.
- Being a pydantic datamodel all the richness of pydantic data validation is available and these models can easily be used in FastAPI and or a ORM
-
- It does not process full JSON data structures but takes simple JSON document with basic elements
-
- Provide a model_name, an example of JSON data and a dict of type overrides
-
- Example:
-
- source_data = {'Name': '48 Rainbow Rd',
- 'GroupAddressStyle': 'ThreeLevel',
- 'LastModified': '2020-12-21T07:02:51.2400232Z',
- 'ProjectStart': '2020-12-03T07:36:03.324856Z',
- 'Comment': '',
- 'CompletionStatus': 'Editing',
- 'LastUsedPuid': '955',
- 'Guid': '0c85957b-c2ae-4985-9752-b300ab385b36'}
-
- source_overrides = {'Guid':{'type':uuid.UUID},
- 'LastModified':{'type':datetime },
- 'ProjectStart':{'type':datetime },
- }
- source_optionals = {"Comment":True}
-
- #create Model
- model_Project=pydanticModelGenerator(
- model_name="Project",
- source_data=source_data,
- overrides=source_overrides,
- optionals=source_optionals).generate_model()
-
- #create instance using DynamicModel
- project_instance=model_Project(**project_info)
-
+ Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
+ source_data is a snapshot of the default values produced by the class
+ params are the names of the actual keys required by __init__
"""
def __init__(
self,
model_name: str = None,
- source_data: str = None,
+ source_data: {} = {},
params: Dict = {},
overrides: Dict = {},
optionals: Dict = {},
@@ -96,4 +65,4 @@ class pydanticModelGenerator:
StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing",
StableDiffusionProcessing().__dict__,
- inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model()
\ No newline at end of file
+ inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model()
--
cgit v1.2.3
From 99013ba68a5fe1bde3621632e5539c03562a3ae8 Mon Sep 17 00:00:00 2001
From: Jonathan
Date: Mon, 17 Oct 2022 03:20:17 -0400
Subject: Update processing.py
---
modules/api/processing.py | 1 -
1 file changed, 1 deletion(-)
(limited to 'modules')
diff --git a/modules/api/processing.py b/modules/api/processing.py
index 4c3d0bd0..e4df93c5 100644
--- a/modules/api/processing.py
+++ b/modules/api/processing.py
@@ -30,7 +30,6 @@ class pydanticModelGenerator:
optionals: Dict = {},
):
def field_type_generator(k, v, overrides, optionals):
- print(k, v)
field_type = str if not overrides.get(k) else overrides[k]["type"]
if v is None:
field_type = Any
--
cgit v1.2.3
From 71d42bb44b257f3fb274c3ad5075a195281ff915 Mon Sep 17 00:00:00 2001
From: Jonathan
Date: Mon, 17 Oct 2022 03:22:19 -0400
Subject: Update api.py
---
modules/api/api.py | 11 +----------
1 file changed, 1 insertion(+), 10 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 9d7c699d..4d9619a8 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -23,10 +23,8 @@ class Api:
app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
- print(txt2imgreq)
p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq))
p.sd_model = shared.sd_model
- print(p)
processed = process_images(p)
b64images = []
@@ -34,13 +32,6 @@ class Api:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
-
- response = {
- "images": b64images,
- "info": processed.js(),
- "parameters": json.dumps(vars(txt2imgreq))
- }
-
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
@@ -57,4 +48,4 @@ class Api:
def launch(self, server_name, port):
app.include_router(self.router)
- uvicorn.run(app, host=server_name, port=port)
\ No newline at end of file
+ uvicorn.run(app, host=server_name, port=port)
--
cgit v1.2.3
From d42125baf62880854ad06af06c15c23e7e50cca6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 17 Oct 2022 11:50:20 +0300
Subject: add missing requirement for api and fix some typos
---
modules/api/api.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 4d9619a8..fd09d352 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -18,7 +18,7 @@ class TextToImageResponse(BaseModel):
class Api:
- def __init__(self, txt2img, img2img, run_extras, run_pnginfo):
+ def __init__(self):
self.router = APIRouter()
app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"])
--
cgit v1.2.3
From 8c6a981d5d9ef30381ac2327460285111550acbc Mon Sep 17 00:00:00 2001
From: Michoko
Date: Mon, 17 Oct 2022 11:05:05 +0200
Subject: Added dark mode switch
Launch the UI in dark mode with the --dark-mode switch
---
modules/shared.py | 2 +-
modules/ui.py | 2 ++
2 files changed, 3 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index c2775603..cbf158e4 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -69,13 +69,13 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image upload
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
+parser.add_argument("--dark-mode", action='store_true', help="launches the UI in dark mode", default=False)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
-
cmd_opts = parser.parse_args()
restricted_opts = [
"samples_filename_pattern",
diff --git a/modules/ui.py b/modules/ui.py
index 43dc88fc..a0cd052e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1783,6 +1783,8 @@ for filename in sorted(os.listdir(jsdir)):
with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
javascript += f"\n"
+if cmd_opts.dark_mode:
+ javascript += "\n\n"
if 'gradio_routes_templates_response' not in globals():
def template_response(*args, **kwargs):
--
cgit v1.2.3
From c408a0b41cfffde184cad35b2d97346342947d83 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Mon, 17 Oct 2022 22:28:43 +0800
Subject: fix two bug
---
modules/images_history.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index 1ae168ca..10e5b970 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -181,7 +181,8 @@ def delete_image(delete_num, name, filenames, image_index, visible_num):
return new_file_list, 1, visible_num
def save_image(file_name):
- shutil.copy2(file_name, opts.outdir_save)
+ if file_name is not None and os.path.exists(file_name):
+ shutil.copy2(file_name, opts.outdir_save)
def get_recent_images(page_index, step, filenames):
page_index = int(page_index)
@@ -327,7 +328,6 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict):
opts = sys_opts
loads_files_num = int(opts.images_history_num_per_page)
num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num)
- backup_flag = opts.images_history_backup
with gr.Blocks(analytics_enabled=False) as images_history:
with gr.Tabs() as tabs:
for tab in ["txt2img", "img2img", "extras", "saved"]:
--
cgit v1.2.3
From 2272cf2f35fafd5cd486bfb4ee89df5bbc625b97 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Mon, 17 Oct 2022 23:04:42 +0800
Subject: fix two bug
---
modules/images_history.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index 10e5b970..1c1790a4 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -133,7 +133,7 @@ def archive_images(dir_name, date_to):
date = sort_array[loads_num][2]
filenames = [x[1] for x in sort_array]
else:
- date = sort_array[-1][2]
+ date = None if len(sort_array) == 0 else sort_array[-1][2]
filenames = [x[1] for x in sort_array]
filenames = [x[1] for x in sort_array if x[2]>= date]
_, image_list, _, visible_num = get_recent_images(1, 0, filenames)
--
cgit v1.2.3
From 2b5b62e768d892773a7ec1d5e8d8cea23aae1254 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Mon, 17 Oct 2022 23:14:03 +0800
Subject: fix two bug
---
modules/images_history.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index 1c1790a4..20324557 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -44,7 +44,7 @@ def traverse_all_files(curr_path, image_list, all_type=False):
return image_list
for file in f_list:
file = os.path.join(curr_path, file)
- if (not all_type) and file[-4:] == ".txt":
+ if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"):
pass
elif os.path.isfile(file) and file[-10:].rfind(".") > 0:
image_list.append(file)
@@ -182,7 +182,7 @@ def delete_image(delete_num, name, filenames, image_index, visible_num):
def save_image(file_name):
if file_name is not None and os.path.exists(file_name):
- shutil.copy2(file_name, opts.outdir_save)
+ shutil.copy(file_name, opts.outdir_save)
def get_recent_images(page_index, step, filenames):
page_index = int(page_index)
--
cgit v1.2.3
From 665beebc0825a6fad410c8252f27f6f6f0bd900b Mon Sep 17 00:00:00 2001
From: Michoko
Date: Mon, 17 Oct 2022 18:24:24 +0200
Subject: Use of a --theme argument for more flexibility
Added possibility to set the theme (light or dark)
---
modules/shared.py | 2 +-
modules/ui.py | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index cbf158e4..fa084c69 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -69,7 +69,7 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image upload
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
-parser.add_argument("--dark-mode", action='store_true', help="launches the UI in dark mode", default=False)
+parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
diff --git a/modules/ui.py b/modules/ui.py
index a0cd052e..d41715fa 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1783,8 +1783,8 @@ for filename in sorted(os.listdir(jsdir)):
with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
javascript += f"\n"
-if cmd_opts.dark_mode:
- javascript += "\n\n"
+if cmd_opts.theme is not None:
+ javascript += f"\n\n"
if 'gradio_routes_templates_response' not in globals():
def template_response(*args, **kwargs):
--
cgit v1.2.3
From d62ef76614624cda99d842a2900242d5b7923eda Mon Sep 17 00:00:00 2001
From: guaneec
Date: Tue, 18 Oct 2022 03:09:50 +0800
Subject: Don't eat colons in booru tags
---
modules/deepbooru.py | 2 --
1 file changed, 2 deletions(-)
(limited to 'modules')
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 4ad334a1..de16b13f 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -157,8 +157,6 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o
# sort by reverse by likelihood and normal for alpha, and format tag text as requested
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold:
- # note: tag_outformat will still have a colon if include_ranks is True
- tag_outformat = tag.replace(':', ' ')
if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ')
if use_escape:
--
cgit v1.2.3
From f80e914ac4aa69a9783b4040813253500b34d925 Mon Sep 17 00:00:00 2001
From: arcticfaded
Date: Mon, 17 Oct 2022 19:10:36 +0000
Subject: example API working with gradio
---
modules/api/api.py | 9 ++++++--
modules/api/processing.py | 56 ++++++++++++++++++++++++++++++++---------------
modules/processing.py | 22 +++++++++++++------
3 files changed, 60 insertions(+), 27 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index fd09d352..5e86c3bf 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -23,8 +23,13 @@ class Api:
app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
- p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq))
- p.sd_model = shared.sd_model
+ populate = txt2imgreq.copy(update={ # Override __init__ params
+ "sd_model": shared.sd_model,
+ "sampler_index": 0,
+ }
+ )
+ p = StableDiffusionProcessingTxt2Img(**vars(populate))
+ # Override object param
processed = process_images(p)
b64images = []
diff --git a/modules/api/processing.py b/modules/api/processing.py
index e4df93c5..b6798241 100644
--- a/modules/api/processing.py
+++ b/modules/api/processing.py
@@ -5,6 +5,24 @@ from modules.processing import StableDiffusionProcessing, Processed, StableDiffu
import inspect
+API_NOT_ALLOWED = [
+ "self",
+ "kwargs",
+ "sd_model",
+ "outpath_samples",
+ "outpath_grids",
+ "sampler_index",
+ "do_not_save_samples",
+ "do_not_save_grid",
+ "extra_generation_params",
+ "overlay_images",
+ "do_not_reload_embeddings",
+ "seed_enable_extras",
+ "prompt_for_display",
+ "sampler_noise_scheduler_override",
+ "ddim_discretize"
+]
+
class ModelDef(BaseModel):
"""Assistance Class for Pydantic Dynamic Model Generation"""
@@ -14,7 +32,7 @@ class ModelDef(BaseModel):
field_value: Any
-class pydanticModelGenerator:
+class PydanticModelGenerator:
"""
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
source_data is a snapshot of the default values produced by the class
@@ -24,30 +42,33 @@ class pydanticModelGenerator:
def __init__(
self,
model_name: str = None,
- source_data: {} = {},
- params: Dict = {},
- overrides: Dict = {},
- optionals: Dict = {},
+ class_instance = None
):
- def field_type_generator(k, v, overrides, optionals):
- field_type = str if not overrides.get(k) else overrides[k]["type"]
- if v is None:
- field_type = Any
- else:
- field_type = type(v)
+ def field_type_generator(k, v):
+ # field_type = str if not overrides.get(k) else overrides[k]["type"]
+ # print(k, v.annotation, v.default)
+ field_type = v.annotation
return Optional[field_type]
+ def merge_class_params(class_):
+ all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
+ parameters = {}
+ for classes in all_classes:
+ parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
+ return parameters
+
+
self._model_name = model_name
- self._json_data = source_data
+ self._class_data = merge_class_params(class_instance)
self._model_def = [
ModelDef(
field=underscore(k),
field_alias=k,
- field_type=field_type_generator(k, v, overrides, optionals),
- field_value=v
+ field_type=field_type_generator(k, v),
+ field_value=v.default
)
- for (k,v) in source_data.items() if k in params
+ for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
def generate_model(self):
@@ -60,8 +81,7 @@ class pydanticModelGenerator:
}
DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True
+ DynamicModel.__config__.allow_mutation = True
return DynamicModel
-StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing",
- StableDiffusionProcessing().__dict__,
- inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model()
+StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model()
diff --git a/modules/processing.py b/modules/processing.py
index 4a7c6ccc..024a4fc3 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -9,6 +9,7 @@ from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure
+from typing import Any, Dict, List, Optional
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
@@ -51,9 +52,15 @@ def get_correct_sampler(p):
return sd_samplers.samplers
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img
+ elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
+ return sd_samplers.samplers
-class StableDiffusionProcessing:
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None, do_not_reload_embeddings=False):
+class StableDiffusionProcessing():
+ """
+ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
+
+ """
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -86,10 +93,10 @@ class StableDiffusionProcessing:
self.denoising_strength: float = 0
self.sampler_noise_scheduler_override = None
self.ddim_discretize = opts.ddim_discretize
- self.s_churn = opts.s_churn
- self.s_tmin = opts.s_tmin
- self.s_tmax = float('inf') # not representable as a standard ui option
- self.s_noise = opts.s_noise
+ self.s_churn = s_churn or opts.s_churn
+ self.s_tmin = s_tmin or opts.s_tmin
+ self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
+ self.s_noise = s_noise or opts.s_noise
if not seed_enable_extras:
self.subseed = -1
@@ -97,6 +104,7 @@ class StableDiffusionProcessing:
self.seed_resize_from_h = 0
self.seed_resize_from_w = 0
+
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -497,7 +505,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs):
+ def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
--
cgit v1.2.3
From 2e28c841f438b2090caac2b9a54eb62ddbda837c Mon Sep 17 00:00:00 2001
From: guaneec
Date: Tue, 18 Oct 2022 03:15:41 +0800
Subject: Oops
---
modules/deepbooru.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index de16b13f..8914662d 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -157,6 +157,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o
# sort by reverse by likelihood and normal for alpha, and format tag text as requested
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold:
+ tag_outformat = tag
if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ')
if use_escape:
--
cgit v1.2.3
From f29b16bad19b6332a15b2ef439864d866277fffb Mon Sep 17 00:00:00 2001
From: arcticfaded
Date: Mon, 17 Oct 2022 20:36:14 +0000
Subject: prevent API from saving
---
modules/api/api.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 5e86c3bf..ce72c5ee 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -26,6 +26,8 @@ class Api:
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": 0,
+ "do_not_save_samples": True,
+ "do_not_save_grid": True
}
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
--
cgit v1.2.3
From c3851a853d99ad35ccedcdd8dbeb6cfbe273439b Mon Sep 17 00:00:00 2001
From: Ryan Voots
Date: Mon, 17 Oct 2022 12:49:33 -0400
Subject: Re-use webui fastapi application rather than requiring one or the
other, not both.
---
modules/api/api.py | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index ce72c5ee..8781cd86 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -2,15 +2,13 @@ from modules.api.processing import StableDiffusionProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
import modules.shared as shared
import uvicorn
-from fastapi import FastAPI, Body, APIRouter
+from fastapi import Body, APIRouter
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
import json
import io
import base64
-app = FastAPI()
-
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
@@ -18,7 +16,7 @@ class TextToImageResponse(BaseModel):
class Api:
- def __init__(self):
+ def __init__(self, app):
self.router = APIRouter()
app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"])
--
cgit v1.2.3
From 247aeb3aaaf2925c7d68a9cf47c975f3e6d3dd33 Mon Sep 17 00:00:00 2001
From: Ryan Voots
Date: Mon, 17 Oct 2022 12:50:45 -0400
Subject: Put API under /sdapi/ so that routing is simpler in the future. This
means that one could allow access to /sdapi/ but not the webui.
---
modules/api/api.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 8781cd86..14613d8c 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -18,7 +18,7 @@ class TextToImageResponse(BaseModel):
class Api:
def __init__(self, app):
self.router = APIRouter()
- app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"])
+ app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
populate = txt2imgreq.copy(update={ # Override __init__ params
--
cgit v1.2.3
From 1df3ff25e6fe2e3f308e45f7a6dd37fb4f1988e6 Mon Sep 17 00:00:00 2001
From: Ryan Voots
Date: Mon, 17 Oct 2022 12:58:34 -0400
Subject: Add --nowebui as a means of disabling the webui and run on the other
port
---
modules/shared.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 6c6405fd..8b436970 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -74,7 +74,8 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
-parser.add_argument("--api", action='store_true', help="use api=True to launch the api instead of the webui")
+parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
+parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
cmd_opts = parser.parse_args()
restricted_opts = [
--
cgit v1.2.3
From 8d5d863a9d11850464fdb6b64f34602803c15ccc Mon Sep 17 00:00:00 2001
From: arcticfaded
Date: Tue, 18 Oct 2022 06:51:53 +0000
Subject: gradio and FastAPI
---
modules/api/api.py | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 14613d8c..ce98cb8c 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -16,9 +16,11 @@ class TextToImageResponse(BaseModel):
class Api:
- def __init__(self, app):
+ def __init__(self, app, queue_lock):
self.router = APIRouter()
- app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
+ self.app = app
+ self.queue_lock = queue_lock
+ self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
populate = txt2imgreq.copy(update={ # Override __init__ params
@@ -30,7 +32,8 @@ class Api:
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
- processed = process_images(p)
+ with self.queue_lock:
+ processed = process_images(p)
b64images = []
for i in processed.images:
@@ -52,5 +55,5 @@ class Api:
raise NotImplementedError
def launch(self, server_name, port):
- app.include_router(self.router)
- uvicorn.run(app, host=server_name, port=port)
+ self.app.include_router(self.router)
+ uvicorn.run(self.app, host=server_name, port=port)
--
cgit v1.2.3
From 8b02662215917d39f76f86b703a322818d5a8ad4 Mon Sep 17 00:00:00 2001
From: trufty
Date: Mon, 17 Oct 2022 10:58:21 -0400
Subject: Disable auto weights swap with config option
---
modules/shared.py | 1 +
modules/ui.py | 4 ++++
2 files changed, 5 insertions(+)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 9603d26e..8a1d1881 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -266,6 +266,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
+ "disable_weights_auto_swap": OptionInfo(False, "Disable auto swapping weights to match model hash in prompts"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
diff --git a/modules/ui.py b/modules/ui.py
index 1dae4a65..75eb0b0c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -542,6 +542,10 @@ def apply_setting(key, value):
if value is None:
return gr.update()
+ # dont allow model to be swapped when model hash exists in prompt
+ if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
+ return gr.update()
+
if key == "sd_model_checkpoint":
ckpt_info = sd_models.get_closet_checkpoint_match(value)
--
cgit v1.2.3
From d2f459c5cf9f728256775dc1c3380c7e9a7e27fb Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 18 Oct 2022 14:22:52 +0300
Subject: clarify the comment for the new option from #2959 and move it to UI
section.
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 8a1d1881..c0d87168 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -266,7 +266,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
- "disable_weights_auto_swap": OptionInfo(False, "Disable auto swapping weights to match model hash in prompts"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
@@ -294,6 +293,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
+ "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
--
cgit v1.2.3
From 97d3ba3941536215ea15431886c7f28300a9d915 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=81=B5=E3=81=81?=
<34892635+fa0311@users.noreply.github.com>
Date: Tue, 18 Oct 2022 17:29:42 +0900
Subject: Add scripts to ui-config,json
---
modules/scripts.py | 15 +++++++++++++--
modules/ui.py | 5 +++++
2 files changed, 18 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/scripts.py b/modules/scripts.py
index ac66d448..3402066d 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -96,6 +96,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
class ScriptRunner:
def __init__(self):
self.scripts = []
+ self.titles = []
def setup_ui(self, is_img2img):
for script_class, path in scripts_data:
@@ -107,9 +108,10 @@ class ScriptRunner:
self.scripts.append(script)
- titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
- dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index")
+ dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
+ dropdown.save_to_config = True
inputs = [dropdown]
for script in self.scripts:
@@ -139,6 +141,15 @@ class ScriptRunner:
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
+ def init_field(title):
+ if title == "None":
+ return
+ script_index = self.titles.index(title)
+ script = self.scripts[script_index]
+ for i in range(script.args_from, script.args_to):
+ inputs[i].visible = True
+
+ dropdown.init_field = init_field
dropdown.change(
fn=select_script,
inputs=[dropdown],
diff --git a/modules/ui.py b/modules/ui.py
index 75eb0b0c..39afbc4e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1753,6 +1753,11 @@ Requested path was: {f}
print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
else:
setattr(obj, field, saved_value)
+ if getattr(x, 'init_field', False):
+ try:
+ x.init_field(saved_value)
+ except Exception:
+ print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
apply_field(x, 'visible')
--
cgit v1.2.3
From de29ec0743fcfb141d8891a3ccbd537ea71bf5b4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=81=B5=E3=81=81?=
<34892635+fa0311@users.noreply.github.com>
Date: Tue, 18 Oct 2022 18:15:00 +0900
Subject: Remove exception handling
---
modules/ui.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 39afbc4e..b38bfb3f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1754,10 +1754,7 @@ Requested path was: {f}
else:
setattr(obj, field, saved_value)
if getattr(x, 'init_field', False):
- try:
- x.init_field(saved_value)
- except Exception:
- print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
+ x.init_field(saved_value)
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
apply_field(x, 'visible')
--
cgit v1.2.3
From 3003438088502774628656790d83fc8074d51ab4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=81=B5=E3=81=81?=
<34892635+fa0311@users.noreply.github.com>
Date: Tue, 18 Oct 2022 18:51:57 +0900
Subject: Add visible for dropdown
---
modules/ui.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index b38bfb3f..fb6eb5a0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1737,7 +1737,7 @@ Requested path was: {f}
print(traceback.format_exc(), file=sys.stderr)
def loadsave(path, x):
- def apply_field(obj, field, condition=None):
+ def apply_field(obj, field, condition=None, init_field=None):
key = path + "/" + field
if getattr(obj,'custom_script_source',None) is not None:
@@ -1753,8 +1753,8 @@ Requested path was: {f}
print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
else:
setattr(obj, field, saved_value)
- if getattr(x, 'init_field', False):
- x.init_field(saved_value)
+ if init_field is not None:
+ init_field(saved_value)
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
apply_field(x, 'visible')
@@ -1780,7 +1780,8 @@ Requested path was: {f}
# Since there are many dropdowns that shouldn't be saved,
# we only mark dropdowns that should be saved.
if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False):
- apply_field(x, 'value', lambda val: val in x.choices)
+ apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
+ apply_field(x, 'visible')
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
--
cgit v1.2.3
From 02622b19191f5f5112db7633c0630e5c7df1b2f7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=81=B5=E3=81=81?=
<34892635+fa0311@users.noreply.github.com>
Date: Tue, 18 Oct 2022 18:52:27 +0900
Subject: update scripts.py
---
modules/scripts.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/scripts.py b/modules/scripts.py
index 3402066d..1039fa9c 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -142,7 +142,7 @@ class ScriptRunner:
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
def init_field(title):
- if title == "None":
+ if title == 'None':
return
script_index = self.titles.index(title)
script = self.scripts[script_index]
--
cgit v1.2.3
From 4c605c5174a9b211c3a88e9aff5f5be92b53fd92 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 16 Oct 2022 17:24:06 +0100
Subject: add shared option for update check
---
modules/shared.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index c0d87168..50dc46ae 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -76,6 +76,7 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
+parser.add_argument("--update-check", action='store_true', help="enable http check to confirm that the currently running version is the most recent release.", default=False)
cmd_opts = parser.parse_args()
restricted_opts = [
--
cgit v1.2.3
From eb299527b1e5d1f83a14641647fca72e8fb305ac Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Tue, 18 Oct 2022 20:14:11 +0800
Subject: Image browser
---
modules/images_history.py | 227 ++++++++++++++++++++++++++++++----------------
modules/shared.py | 7 +-
modules/ui.py | 2 +-
3 files changed, 154 insertions(+), 82 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index 20324557..d56f3a25 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -4,6 +4,7 @@ import time
import hashlib
import gradio
system_bak_path = "webui_log_and_bak"
+browser_tabname = "custom"
def is_valid_date(date):
try:
time.strptime(date, "%Y%m%d")
@@ -99,13 +100,15 @@ def auto_sorting(dir_name):
date_list.append(today)
return sorted(date_list, reverse=True)
-def archive_images(dir_name, date_to):
+def archive_images(dir_name, date_to):
+
filenames = []
loads_num =int(opts.images_history_num_per_page * opts.images_history_pages_num)
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ date_to = today if date_to is None or date_to == "" else date_to
+ date_to_bak = date_to
if opts.images_history_reconstruct_directory:
- date_list = auto_sorting(dir_name)
- today = time.strftime("%Y%m%d",time.localtime(time.time()))
- date_to = today if date_to is None or date_to == "" else date_to
+ date_list = auto_sorting(dir_name)
for date in date_list:
if date <= date_to:
path = os.path.join(dir_name, date)
@@ -120,7 +123,7 @@ def archive_images(dir_name, date_to):
tmparray = [(os.path.getmtime(file), file) for file in filenames ]
date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400
filenames = []
- date_list = {}
+ date_list = {date_to:None}
date = time.strftime("%Y%m%d",time.localtime(time.time()))
for t, f in tmparray:
date = time.strftime("%Y%m%d",time.localtime(t))
@@ -133,22 +136,29 @@ def archive_images(dir_name, date_to):
date = sort_array[loads_num][2]
filenames = [x[1] for x in sort_array]
else:
- date = None if len(sort_array) == 0 else sort_array[-1][2]
+ date = date_to if len(sort_array) == 0 else sort_array[-1][2]
filenames = [x[1] for x in sort_array]
- filenames = [x[1] for x in sort_array if x[2]>= date]
- _, image_list, _, visible_num = get_recent_images(1, 0, filenames)
+ filenames = [x[1] for x in sort_array if x[2]>= date]
+ num = len(filenames)
+ last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000))
+ date = date[:4] + "-" + date[4:6] + "-" + date[6:8]
+ date_to_bak = date_to_bak[:4] + "-" + date_to_bak[4:6] + "-" + date_to_bak[6:8]
+ load_info = f"Loaded {(num + 1) // opts.images_history_pages_num} pades, {num} images, during {date} - {date_to_bak}"
+ _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames)
return (
gradio.Dropdown.update(choices=date_list, value=date_to),
- date,
+ load_info,
filenames,
1,
image_list,
"",
- visible_num
+ "",
+ visible_num,
+ last_date_from
)
-def newest_click(dir_name, date_to):
- return archive_images(dir_name, time.strftime("%Y%m%d",time.localtime(time.time())))
+
+
def delete_image(delete_num, name, filenames, image_index, visible_num):
if name == "":
@@ -196,7 +206,29 @@ def get_recent_images(page_index, step, filenames):
length = len(filenames)
visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page
visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num
- return page_index, image_list, "", visible_num
+ return page_index, image_list, "", "", visible_num
+
+def newest_click(date_to):
+ if date_to is None:
+ return time.strftime("%Y%m%d",time.localtime(time.time())), []
+ else:
+ return None, []
+def forward_click(last_date_from, date_to_recorder):
+ if len(date_to_recorder) == 0:
+ return None, []
+ if last_date_from == date_to_recorder[-1]:
+ date_to_recorder = date_to_recorder[:-1]
+ if len(date_to_recorder) == 0:
+ return None, []
+ return date_to_recorder[-1], date_to_recorder[:-1]
+
+def backward_click(last_date_from, date_to_recorder):
+ if last_date_from is None or last_date_from == "":
+ return time.strftime("%Y%m%d",time.localtime(time.time())), []
+ if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]:
+ date_to_recorder.append(last_date_from)
+ return last_date_from, date_to_recorder
+
def first_page_click(page_index, filenames):
return get_recent_images(1, 0, filenames)
@@ -214,13 +246,33 @@ def page_index_change(page_index, filenames):
return get_recent_images(page_index, 0, filenames)
def show_image_info(tabname_box, num, page_index, filenames):
- file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))]
- return file, num, file
+ file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))]
+ tm = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file)))
+ return file, tm, num, file
def enable_page_buttons():
return gradio.update(visible=True)
+def change_dir(img_dir, date_to):
+ warning = None
+ try:
+ if os.path.exists(img_dir):
+ try:
+ f = os.listdir(img_dir)
+ except:
+ warning = f"'{img_dir} is not a directory"
+ else:
+ warning = "The directory is not exist"
+ except:
+ warning = "The format of the directory is incorrect"
+ if warning is None:
+ today = time.strftime("%Y%m%d",time.localtime(time.time()))
+ return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today
+ else:
+ return gradio.update(visible=True), gradio.update(visible=False), warning, date_to
+
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
+ custom_dir = False
if tabname == "txt2img":
dir_name = opts.outdir_txt2img_samples
elif tabname == "img2img":
@@ -229,69 +281,85 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
dir_name = opts.outdir_extras_samples
elif tabname == "saved":
dir_name = opts.outdir_save
+ else:
+ custom_dir = True
+ dir_name = None
+
+ if not custom_dir:
+ d = dir_name.split("/")
+ dir_name = d[0]
+ for p in d[1:]:
+ dir_name = os.path.join(dir_name, p)
+ if not os.path.exists(dir_name):
+ os.makedirs(dir_name)
- d = dir_name.split("/")
- dir_name = d[0]
- for p in d[1:]:
- dir_name = os.path.join(dir_name, p)
- if not os.path.exists(dir_name):
- os.makedirs(dir_name)
-
- with gr.Column() as page_panel:
- with gr.Row(visible=False) as turn_page_buttons:
- renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
- first_page = gr.Button('First Page')
- prev_page = gr.Button('Prev Page')
- page_index = gr.Number(value=1, label="Page Index")
- next_page = gr.Button('Next Page')
- end_page = gr.Button('End Page')
-
- with gr.Row(elem_id=tabname + "_images_history"):
- with gr.Column(scale=2):
- with gr.Row():
- newest = gr.Button('Reload', elem_id=tabname + "_images_history_start")
- date_from = gr.Textbox(label="Date from", interactive=False)
- date_to = gr.Dropdown(label="Date to")
-
- history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
- with gr.Row():
- delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
- delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
-
- with gr.Column():
- with gr.Row():
- if tabname != "saved":
- save_btn = gr.Button('Save')
- pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
- pnginfo_send_to_img2img = gr.Button('Send to img2img')
- with gr.Row():
- with gr.Column():
- img_file_info = gr.Textbox(label="Generate Info", interactive=False)
- img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
+ with gr.Column() as page_panel:
+ with gr.Row():
+ img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory")
+ with gr.Row(visible=False) as warning:
+ warning_box = gr.Textbox("Message", interactive=False)
+ with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel:
+ with gr.Column(scale=2):
+ with gr.Row():
+ backward = gr.Button('Backward')
+ date_to = gr.Dropdown(label="Date to")
+ forward = gr.Button('Forward')
+ newest = gr.Button('Reload', elem_id=tabname + "_images_history_start")
+ with gr.Row():
+ load_info = gr.Textbox(show_label=False, interactive=False)
+ with gr.Row(visible=False) as turn_page_buttons:
+ renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
+ first_page = gr.Button('First Page')
+ prev_page = gr.Button('Prev Page')
+ page_index = gr.Number(value=1, label="Page Index")
+ next_page = gr.Button('Next Page')
+ end_page = gr.Button('End Page')
+
+ history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num)
+ with gr.Row():
+ delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
+ delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
- # hiden items
- with gr.Row(visible=False):
- visible_img_num = gr.Number()
- img_path = gr.Textbox(dir_name)
- tabname_box = gr.Textbox(tabname)
- image_index = gr.Textbox(value=-1)
- set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
- filenames = gr.State()
- all_images_list = gr.State()
- hidden = gr.Image(type="pil")
- info1 = gr.Textbox()
- info2 = gr.Textbox()
+ with gr.Column():
+ with gr.Row():
+ if tabname != "saved":
+ save_btn = gr.Button('Save')
+ pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
+ pnginfo_send_to_img2img = gr.Button('Send to img2img')
+ with gr.Row():
+ with gr.Column():
+ img_file_info = gr.Textbox(label="Generate Info", interactive=False)
+ img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
+ img_file_time= gr.Textbox(value="", label="Create Time", interactive=False)
-
+
+ # hiden items
+ with gr.Row(): #visible=False):
+ visible_img_num = gr.Number()
+ date_to_recorder = gr.State([])
+ last_date_from = gr.Textbox()
+ tabname_box = gr.Textbox(tabname)
+ image_index = gr.Textbox(value=-1)
+ set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
+ filenames = gr.State()
+ all_images_list = gr.State()
+ hidden = gr.Image(type="pil")
+ info1 = gr.Textbox()
+ info2 = gr.Textbox()
+
+ img_path.submit(change_dir, inputs=[img_path, date_to], outputs=[warning, main_panel, warning_box, date_to])
#change date
- change_date_output = [date_to, date_from, filenames, page_index, history_gallery, img_file_name, visible_img_num]
- newest.click(newest_click, inputs=[img_path, date_to], outputs=change_date_output)
- date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output)
- newest.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
- date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
- date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
- newest.click(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
+ change_date_output = [date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from]
+
+ date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output)
+ date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
+ date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+
+ newest.click(newest_click, inputs=[date_to], outputs=[date_to, date_to_recorder])
+ forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder])
+ backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder])
+
#delete
delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num])
@@ -301,7 +369,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
#turn page
gallery_inputs = [page_index, filenames]
- gallery_outputs = [page_index, history_gallery, img_file_name, visible_img_num]
+ gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num]
first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
@@ -317,12 +385,14 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
# other funcitons
- set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden])
+ set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden])
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
+
+
def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict):
global opts;
opts = sys_opts
@@ -330,10 +400,11 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict):
num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num)
with gr.Blocks(analytics_enabled=False) as images_history:
with gr.Tabs() as tabs:
- for tab in ["txt2img", "img2img", "extras", "saved"]:
+ for tab in [browser_tabname, "txt2img", "img2img", "extras", "saved"]:
with gr.Tab(tab):
- with gr.Blocks(analytics_enabled=False) as images_history_img2img:
+ with gr.Blocks(analytics_enabled=False) :
show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
- gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False)
+ #gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False)
+ gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_finish_render", visible=False)
return images_history
diff --git a/modules/shared.py b/modules/shared.py
index c2ea4186..1811018d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -309,10 +309,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
-options_templates.update(options_section(('images-history', "Images history"), {
- "images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"),
+options_templates.update(options_section(('images-history', "Images Browser"), {
+ #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"),
"images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"),
- "images_history_pages_num": OptionInfo(6, "Maximum number of pages per load "),
+ "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "),
+ "images_history_grid_num": OptionInfo(6, "Number of grids in each row"),
}))
diff --git a/modules/ui.py b/modules/ui.py
index 43dc88fc..85abac4d 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1548,7 +1548,7 @@ Requested path was: {f}
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
- (images_history, "History", "images_history"),
+ (images_history, "Image Browser", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
(settings_interface, "Settings", "settings"),
--
cgit v1.2.3
From 433a7525c1f5eb5963340e0cc45d31038ede3f7e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 18 Oct 2022 15:18:02 +0300
Subject: remove shared option for update check (because it is not an argument
of webui) have launch.py examine both COMMANDLINE_ARGS as well as argv for
its arguments
---
modules/shared.py | 1 -
1 file changed, 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 50dc46ae..c0d87168 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -76,7 +76,6 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
-parser.add_argument("--update-check", action='store_true', help="enable http check to confirm that the currently running version is the most recent release.", default=False)
cmd_opts = parser.parse_args()
restricted_opts = [
--
cgit v1.2.3
From 2f448d97a9427f9a7bad19cf608561b2878ab2da Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Mon, 17 Oct 2022 23:18:21 +0900
Subject: styles.csv encoding utf8 to utf-8-sig
utf-8-bom for better compatibility for some programs
---
modules/styles.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/styles.py b/modules/styles.py
index d44dfc1a..3bf5c5b6 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -45,7 +45,7 @@ class StyleDatabase:
if not os.path.exists(path):
return
- with open(path, "r", encoding="utf8", newline='') as file:
+ with open(path, "r", encoding="utf-8-sig", newline='') as file:
reader = csv.DictReader(file)
for row in reader:
# Support loading old CSV format with "name, text"-columns
@@ -79,7 +79,7 @@ class StyleDatabase:
def save_styles(self, path: str) -> None:
# Write to temporary file first, so we don't nuke the file if something goes wrong
fd, temp_path = tempfile.mkstemp(".csv")
- with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
+ with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
--
cgit v1.2.3
From e20b7e30fe17744acb74ad33c87c0963525ea921 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 18 Oct 2022 15:33:24 +0300
Subject: fix for add difference model merging
---
modules/extras.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index c908b43e..03f6085e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -216,8 +216,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
if theta_func1:
for key in tqdm.tqdm(theta_1.keys()):
if 'model' in key:
- t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
- theta_1[key] = theta_func1(theta_1[key], t2)
+ if key in theta_2:
+ t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
+ theta_1[key] = theta_func1(theta_1[key], t2)
+ else:
+ theta_1[key] = 0
del theta_2, teritary_model
for key in tqdm.tqdm(theta_0.keys()):
--
cgit v1.2.3
From ec1924ee5789b72c31c65932b549c59ccae0cdd6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 18 Oct 2022 16:05:52 +0300
Subject: additional fix for difference model merging
---
modules/extras.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index 03f6085e..b853fa5b 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -220,7 +220,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
theta_1[key] = theta_func1(theta_1[key], t2)
else:
- theta_1[key] = 0
+ theta_1[key] = torch.zeros_like(theta_1[key])
del theta_2, teritary_model
for key in tqdm.tqdm(theta_0.keys()):
--
cgit v1.2.3
From b7e78ef692fe912916de6e54f6e2521b000d650c Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Tue, 18 Oct 2022 22:21:54 +0800
Subject: Image browser improve
---
modules/images_history.py | 43 ++++++++++++++++++++++---------------------
1 file changed, 22 insertions(+), 21 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index d56f3a25..a40cdc0e 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -100,14 +100,15 @@ def auto_sorting(dir_name):
date_list.append(today)
return sorted(date_list, reverse=True)
-def archive_images(dir_name, date_to):
-
+def archive_images(dir_name, date_to):
filenames = []
- loads_num =int(opts.images_history_num_per_page * opts.images_history_pages_num)
+ batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num)
+ if batch_size <= 0:
+ batch_size = opts.images_history_num_per_page * 6
today = time.strftime("%Y%m%d",time.localtime(time.time()))
date_to = today if date_to is None or date_to == "" else date_to
date_to_bak = date_to
- if opts.images_history_reconstruct_directory:
+ if False: #opts.images_history_reconstruct_directory:
date_list = auto_sorting(dir_name)
for date in date_list:
if date <= date_to:
@@ -115,11 +116,13 @@ def archive_images(dir_name, date_to):
if date == today and not os.path.exists(path):
continue
filenames = traverse_all_files(path, filenames)
- if len(filenames) > loads_num:
+ if len(filenames) > batch_size:
break
filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file))
else:
- filenames = traverse_all_files(dir_name, filenames)
+ filenames = traverse_all_files(dir_name, filenames)
+ total_num = len(filenames)
+ batch_count = len(filenames) + 1 // batch_size + 1
tmparray = [(os.path.getmtime(file), file) for file in filenames ]
date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400
filenames = []
@@ -132,8 +135,8 @@ def archive_images(dir_name, date_to):
filenames.append((t, f ,date))
date_list = sorted(list(date_list.keys()), reverse=True)
sort_array = sorted(filenames, key=lambda x:-x[0])
- if len(sort_array) > loads_num:
- date = sort_array[loads_num][2]
+ if len(sort_array) > batch_size:
+ date = sort_array[batch_size][2]
filenames = [x[1] for x in sort_array]
else:
date = date_to if len(sort_array) == 0 else sort_array[-1][2]
@@ -141,9 +144,9 @@ def archive_images(dir_name, date_to):
filenames = [x[1] for x in sort_array if x[2]>= date]
num = len(filenames)
last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000))
- date = date[:4] + "-" + date[4:6] + "-" + date[6:8]
- date_to_bak = date_to_bak[:4] + "-" + date_to_bak[4:6] + "-" + date_to_bak[6:8]
- load_info = f"Loaded {(num + 1) // opts.images_history_pages_num} pades, {num} images, during {date} - {date_to_bak}"
+ date = date[:4] + "/" + date[4:6] + "/" + date[6:8]
+ date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8]
+ load_info = f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages"
_, image_list, _, _, visible_num = get_recent_images(1, 0, filenames)
return (
gradio.Dropdown.update(choices=date_list, value=date_to),
@@ -154,12 +157,10 @@ def archive_images(dir_name, date_to):
"",
"",
visible_num,
- last_date_from
+ last_date_from,
+ #gradio.update(visible=batch_count > 1)
)
-
-
-
def delete_image(delete_num, name, filenames, image_index, visible_num):
if name == "":
return filenames, delete_num
@@ -295,16 +296,16 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
with gr.Column() as page_panel:
with gr.Row():
- img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory")
+ img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir)
with gr.Row(visible=False) as warning:
warning_box = gr.Textbox("Message", interactive=False)
with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel:
with gr.Column(scale=2):
- with gr.Row():
- backward = gr.Button('Backward')
- date_to = gr.Dropdown(label="Date to")
- forward = gr.Button('Forward')
+ with gr.Row() as batch_panel:
+ forward = gr.Button('Forward')
+ date_to = gr.Dropdown(label="Date to")
+ backward = gr.Button('Backward')
newest = gr.Button('Reload', elem_id=tabname + "_images_history_start")
with gr.Row():
load_info = gr.Textbox(show_label=False, interactive=False)
@@ -335,7 +336,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
# hiden items
- with gr.Row(): #visible=False):
+ with gr.Row(visible=False):
visible_img_num = gr.Number()
date_to_recorder = gr.State([])
last_date_from = gr.Textbox()
--
cgit v1.2.3
From cbf15edbf90a68a08eeab40af5df577ba4ac90b6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 18 Oct 2022 17:23:38 +0300
Subject: remove dependence on TQDM for sampler progress/interrupt
functionality
---
modules/processing.py | 6 ---
modules/sd_samplers.py | 107 +++++++++++++++++++++++++++----------------------
2 files changed, 58 insertions(+), 55 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index deb6125e..346eea88 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -402,12 +402,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
- if state.interrupted or state.skipped:
-
- # if we are interrupted, sample returns just noise
- # use the image collected previously in sampler loop
- samples_ddim = shared.state.current_latent
-
samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 20309e06..b58e810b 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -98,25 +98,8 @@ def store_latent(decoded):
shared.state.current_image = sample_to_image(decoded)
-
-def extended_tdqm(sequence, *args, desc=None, **kwargs):
- state.sampling_steps = len(sequence)
- state.sampling_step = 0
-
- seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
-
- for x in seq:
- if state.interrupted or state.skipped:
- break
-
- yield x
-
- state.sampling_step += 1
- shared.total_tqdm.update()
-
-
-ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
-ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
+class InterruptedException(BaseException):
+ pass
class VanillaStableDiffusionSampler:
@@ -128,14 +111,32 @@ class VanillaStableDiffusionSampler:
self.init_latent = None
self.sampler_noises = None
self.step = 0
+ self.stop_at = None
self.eta = None
self.default_eta = 0.0
self.config = None
+ self.last_latent = None
def number_of_needed_noises(self, p):
return 0
+ def launch_sampling(self, steps, func):
+ state.sampling_steps = steps
+ state.sampling_step = 0
+
+ try:
+ return func()
+ except InterruptedException:
+ return self.last_latent
+
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+ if state.interrupted or state.skipped:
+ raise InterruptedException
+
+ if self.stop_at is not None and self.step > self.stop_at:
+ raise InterruptedException
+
+
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
@@ -159,11 +160,16 @@ class VanillaStableDiffusionSampler:
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None:
- store_latent(self.init_latent * self.mask + self.nmask * res[1])
+ self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
else:
- store_latent(res[1])
+ self.last_latent = res[1]
+
+ store_latent(self.last_latent)
self.step += 1
+ state.sampling_step = self.step
+ shared.total_tqdm.update()
+
return res
def initialize(self, p):
@@ -192,7 +198,7 @@ class VanillaStableDiffusionSampler:
self.init_latent = x
self.step = 0
- samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
+ samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
@@ -206,9 +212,9 @@ class VanillaStableDiffusionSampler:
# existing code fails with certain step counts, like 9
try:
- samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
except Exception:
- samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim
@@ -223,6 +229,9 @@ class CFGDenoiser(torch.nn.Module):
self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale):
+ if state.interrupted or state.skipped:
+ raise InterruptedException
+
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
@@ -268,25 +277,6 @@ class CFGDenoiser(torch.nn.Module):
return denoised
-def extended_trange(sampler, count, *args, **kwargs):
- state.sampling_steps = count
- state.sampling_step = 0
-
- seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
-
- for x in seq:
- if state.interrupted or state.skipped:
- break
-
- if sampler.stop_at is not None and x > sampler.stop_at:
- break
-
- yield x
-
- state.sampling_step += 1
- shared.total_tqdm.update()
-
-
class TorchHijack:
def __init__(self, kdiff_sampler):
self.kdiff_sampler = kdiff_sampler
@@ -314,9 +304,28 @@ class KDiffusionSampler:
self.eta = None
self.default_eta = 1.0
self.config = None
+ self.last_latent = None
def callback_state(self, d):
- store_latent(d["denoised"])
+ step = d['i']
+ latent = d["denoised"]
+ store_latent(latent)
+ self.last_latent = latent
+
+ if self.stop_at is not None and step > self.stop_at:
+ raise InterruptedException
+
+ state.sampling_step = step
+ shared.total_tqdm.update()
+
+ def launch_sampling(self, steps, func):
+ state.sampling_steps = steps
+ state.sampling_step = 0
+
+ try:
+ return func()
+ except InterruptedException:
+ return self.last_latent
def number_of_needed_noises(self, p):
return p.steps
@@ -339,9 +348,6 @@ class KDiffusionSampler:
self.sampler_noise_index = 0
self.eta = p.eta or opts.eta_ancestral
- if hasattr(k_diffusion.sampling, 'trange'):
- k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
-
if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self)
@@ -383,8 +389,9 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x
- return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps
@@ -406,6 +413,8 @@ class KDiffusionSampler:
extra_params_kwargs['n'] = steps
else:
extra_params_kwargs['sigmas'] = sigmas
- samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
+
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
return samples
--
cgit v1.2.3
From 6021f7a75f7b5208a2be15cda5526028152f922d Mon Sep 17 00:00:00 2001
From: discus0434
Date: Wed, 19 Oct 2022 00:51:36 +0900
Subject: add options to custom hypernetwork layer structure
---
modules/hypernetworks/hypernetwork.py | 88 ++++++++++++++++++++++++++---------
modules/shared.py | 4 +-
2 files changed, 70 insertions(+), 22 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 4905710e..cadb9911 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,52 +1,98 @@
+import csv
import datetime
import glob
import html
import os
import sys
import traceback
-import tqdm
-import csv
+import modules.textual_inversion.dataset
import torch
-
-from ldm.util import default
-from modules import devices, shared, processing, sd_models
-import torch
-from torch import einsum
+import tqdm
from einops import rearrange, repeat
-import modules.textual_inversion.dataset
+from ldm.util import default
+from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from torch import einsum
+
+
+def parse_layer_structure(dim, state_dict):
+ i = 0
+ res = [1]
+ while (key := "linear.{}.weight".format(i)) in state_dict:
+ weight = state_dict[key]
+ res.append(len(weight) // dim)
+ i += 1
+ return res
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
+ layer_structure = None
+ add_layer_norm = False
def __init__(self, dim, state_dict=None):
super().__init__()
+ if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None:
+ layer_structure = (1, 2, 1)
+ else:
+ if self.layer_structure is not None:
+ assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+ assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
+ layer_structure = self.layer_structure
+ else:
+ layer_structure = parse_layer_structure(dim, state_dict)
+
+ linears = []
+ for i in range(len(layer_structure) - 1):
+ linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+ if self.add_layer_norm:
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
- self.linear1 = torch.nn.Linear(dim, dim * 2)
- self.linear2 = torch.nn.Linear(dim * 2, dim)
+ self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
- self.load_state_dict(state_dict, strict=True)
+ try:
+ self.load_state_dict(state_dict)
+ except RuntimeError:
+ self.try_load_previous(state_dict)
else:
-
- self.linear1.weight.data.normal_(mean=0.0, std=0.01)
- self.linear1.bias.data.zero_()
- self.linear2.weight.data.normal_(mean=0.0, std=0.01)
- self.linear2.bias.data.zero_()
+ for layer in self.linear:
+ layer.weight.data.normal_(mean = 0.0, std = 0.01)
+ layer.bias.data.zero_()
self.to(devices.device)
+ def try_load_previous(self, state_dict):
+ states = self.state_dict()
+ states['linear.0.bias'].copy_(state_dict['linear1.bias'])
+ states['linear.0.weight'].copy_(state_dict['linear1.weight'])
+ states['linear.1.bias'].copy_(state_dict['linear2.bias'])
+ states['linear.1.weight'].copy_(state_dict['linear2.weight'])
+
def forward(self, x):
- return x + (self.linear2(self.linear1(x))) * self.multiplier
+ return x + self.linear(x) * self.multiplier
+
+ def trainables(self):
+ res = []
+ for layer in self.linear:
+ res += [layer.weight, layer.bias]
+ return res
def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
+def apply_layer_structure(value=None):
+ HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure
+
+
+def apply_layer_norm(value=None):
+ HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm
+
+
class Hypernetwork:
filename = None
name = None
@@ -68,7 +114,7 @@ class Hypernetwork:
for k, layers in self.layers.items():
for layer in layers:
layer.train()
- res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
+ res += layer.trainables()
return res
@@ -226,7 +272,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
-
+ assert ds.length > 1, "Dataset should contain more than 1 images"
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
@@ -261,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
-# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss = shared.sd_model(x, c)[0]
del x
@@ -283,7 +329,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{mean_loss:.7f}",
- "learn_rate": scheduler.learn_rate
+ "learn_rate": f"{scheduler.learn_rate:.7f}"
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
diff --git a/modules/shared.py b/modules/shared.py
index c0d87168..c87ce70e 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,7 +13,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers, sd_models, localization
+from modules import sd_models, sd_samplers, localization
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
@@ -258,6 +258,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+ "sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}),
+ "sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
--
cgit v1.2.3
From a5611ea5026bd8e12d8e84023384c369d0511dda Mon Sep 17 00:00:00 2001
From: discus0434
Date: Wed, 19 Oct 2022 01:00:01 +0900
Subject: update
---
modules/hypernetworks/hypernetwork.py | 14 ++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index cadb9911..c5835bce 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,20 +1,22 @@
-import csv
import datetime
import glob
import html
import os
import sys
import traceback
+import tqdm
+import csv
-import modules.textual_inversion.dataset
import torch
-import tqdm
-from einops import rearrange, repeat
+
from ldm.util import default
-from modules import devices, processing, sd_models, shared
+from modules import devices, shared, processing, sd_models
+import torch
+from torch import einsum
+from einops import rearrange, repeat
+import modules.textual_inversion.dataset
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
-from torch import einsum
def parse_layer_structure(dim, state_dict):
--
cgit v1.2.3
From 7f2095c6c8db82a5c9cd7c7177f6ba856a2cc676 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Wed, 19 Oct 2022 01:01:22 +0900
Subject: update
---
modules/shared.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index c87ce70e..6b6d5c41 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,7 +13,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_models, sd_samplers, localization
+from modules import sd_samplers, sd_models, localization
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
@@ -135,7 +135,7 @@ class State:
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
-
+
def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
--
cgit v1.2.3
From e40ba281f1b419cf99552962ea01d87d699840a5 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Wed, 19 Oct 2022 01:03:58 +0900
Subject: update
---
modules/hypernetworks/hypernetwork.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index c5835bce..082165f4 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -309,7 +309,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
- c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss = shared.sd_model(x, c)[0]
del x
@@ -331,7 +331,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{mean_loss:.7f}",
- "learn_rate": f"{scheduler.learn_rate:.7f}"
+ "learn_rate": scheduler.learn_rate
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
--
cgit v1.2.3
From e7f4808505f7a6339927c32b9a0c01bc9134bdeb Mon Sep 17 00:00:00 2001
From: arcticfaded
Date: Tue, 18 Oct 2022 19:04:56 +0000
Subject: provide sampler by name
---
modules/api/api.py | 12 ++++++++++--
modules/api/processing.py | 16 ++++++++++++++--
2 files changed, 24 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index ce98cb8c..ff9df0d1 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,14 +1,17 @@
from modules.api.processing import StableDiffusionProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
+from modules.sd_samplers import samplers_k_diffusion
import modules.shared as shared
import uvicorn
-from fastapi import Body, APIRouter
+from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
import json
import io
import base64
+sampler_to_index = lambda name: next(filter(lambda row: name in row[1][2], enumerate(samplers_k_diffusion)), None)
+
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
@@ -23,9 +26,14 @@ class Api:
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
+ sampler_index = sampler_to_index(txt2imgreq.sampler_index)
+
+ if sampler_index is None:
+ raise HTTPException(status_code=404, detail="Sampler not found")
+
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
- "sampler_index": 0,
+ "sampler_index": sampler_index[0],
"do_not_save_samples": True,
"do_not_save_grid": True
}
diff --git a/modules/api/processing.py b/modules/api/processing.py
index b6798241..2e6483ee 100644
--- a/modules/api/processing.py
+++ b/modules/api/processing.py
@@ -42,7 +42,8 @@ class PydanticModelGenerator:
def __init__(
self,
model_name: str = None,
- class_instance = None
+ class_instance = None,
+ additional_fields = None,
):
def field_type_generator(k, v):
# field_type = str if not overrides.get(k) else overrides[k]["type"]
@@ -70,6 +71,13 @@ class PydanticModelGenerator:
)
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
+
+ for fields in additional_fields:
+ self._model_def.append(ModelDef(
+ field=underscore(fields["key"]),
+ field_alias=fields["key"],
+ field_type=fields["type"],
+ field_value=fields["default"]))
def generate_model(self):
"""
@@ -84,4 +92,8 @@ class PydanticModelGenerator:
DynamicModel.__config__.allow_mutation = True
return DynamicModel
-StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model()
+StableDiffusionProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingTxt2Img",
+ StableDiffusionProcessingTxt2Img,
+ [{"key": "sampler_index", "type": str, "default": "k_euler_a"}]
+).generate_model()
--
cgit v1.2.3
From 538bc89c269743e56b07ef2b471d1ce0a39b6776 Mon Sep 17 00:00:00 2001
From: yfszzx
Date: Wed, 19 Oct 2022 11:27:51 +0800
Subject: Image browser improved
---
modules/images_history.py | 135 +++++++++++++++++++++++++---------------------
modules/shared.py | 5 ++
modules/ui.py | 2 +-
3 files changed, 80 insertions(+), 62 deletions(-)
(limited to 'modules')
diff --git a/modules/images_history.py b/modules/images_history.py
index a40cdc0e..78fd0543 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -4,7 +4,9 @@ import time
import hashlib
import gradio
system_bak_path = "webui_log_and_bak"
-browser_tabname = "custom"
+custom_tab_name = "custom fold"
+faverate_tab_name = "favorites"
+tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name]
def is_valid_date(date):
try:
time.strptime(date, "%Y%m%d")
@@ -122,7 +124,6 @@ def archive_images(dir_name, date_to):
else:
filenames = traverse_all_files(dir_name, filenames)
total_num = len(filenames)
- batch_count = len(filenames) + 1 // batch_size + 1
tmparray = [(os.path.getmtime(file), file) for file in filenames ]
date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400
filenames = []
@@ -146,10 +147,12 @@ def archive_images(dir_name, date_to):
last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000))
date = date[:4] + "/" + date[4:6] + "/" + date[6:8]
date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8]
- load_info = f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages"
+ load_info = ""
+ load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages"
+ load_info += "
"
_, image_list, _, _, visible_num = get_recent_images(1, 0, filenames)
return (
- gradio.Dropdown.update(choices=date_list, value=date_to),
+ date_to,
load_info,
filenames,
1,
@@ -158,7 +161,7 @@ def archive_images(dir_name, date_to):
"",
visible_num,
last_date_from,
- #gradio.update(visible=batch_count > 1)
+ gradio.update(visible=total_num > num)
)
def delete_image(delete_num, name, filenames, image_index, visible_num):
@@ -209,7 +212,7 @@ def get_recent_images(page_index, step, filenames):
visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num
return page_index, image_list, "", "", visible_num
-def newest_click(date_to):
+def loac_batch_click(date_to):
if date_to is None:
return time.strftime("%Y%m%d",time.localtime(time.time())), []
else:
@@ -248,7 +251,7 @@ def page_index_change(page_index, filenames):
def show_image_info(tabname_box, num, page_index, filenames):
file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))]
- tm = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file)))
+ tm = "" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "
"
return file, tm, num, file
def enable_page_buttons():
@@ -268,9 +271,9 @@ def change_dir(img_dir, date_to):
warning = "The format of the directory is incorrect"
if warning is None:
today = time.strftime("%Y%m%d",time.localtime(time.time()))
- return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today
+ return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True)
else:
- return gradio.update(visible=True), gradio.update(visible=False), warning, date_to
+ return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False)
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
custom_dir = False
@@ -280,7 +283,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
dir_name = opts.outdir_img2img_samples
elif tabname == "extras":
dir_name = opts.outdir_extras_samples
- elif tabname == "saved":
+ elif tabname == faverate_tab_name:
dir_name = opts.outdir_save
else:
custom_dir = True
@@ -295,22 +298,26 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
os.makedirs(dir_name)
with gr.Column() as page_panel:
- with gr.Row():
- img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir)
+ with gr.Row():
+ with gr.Column(scale=1, visible=not custom_dir) as load_batch_box:
+ load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True)
+ with gr.Column(scale=4):
+ with gr.Row():
+ img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir)
+ with gr.Row():
+ with gr.Column(visible=False, scale=1) as batch_panel:
+ with gr.Row():
+ forward = gr.Button('Prev batch')
+ backward = gr.Button('Next batch')
+ with gr.Column(scale=3):
+ load_info = gr.HTML(visible=not custom_dir)
with gr.Row(visible=False) as warning:
warning_box = gr.Textbox("Message", interactive=False)
with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel:
- with gr.Column(scale=2):
- with gr.Row() as batch_panel:
- forward = gr.Button('Forward')
- date_to = gr.Dropdown(label="Date to")
- backward = gr.Button('Backward')
- newest = gr.Button('Reload', elem_id=tabname + "_images_history_start")
- with gr.Row():
- load_info = gr.Textbox(show_label=False, interactive=False)
- with gr.Row(visible=False) as turn_page_buttons:
- renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
+ with gr.Column(scale=2):
+ with gr.Row(visible=True) as turn_page_buttons:
+ #date_to = gr.Dropdown(label="Date to")
first_page = gr.Button('First Page')
prev_page = gr.Button('Prev Page')
page_index = gr.Number(value=1, label="Page Index")
@@ -322,50 +329,54 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
- with gr.Column():
- with gr.Row():
- if tabname != "saved":
- save_btn = gr.Button('Save')
- pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
- pnginfo_send_to_img2img = gr.Button('Send to img2img')
+ with gr.Column():
with gr.Row():
with gr.Column():
- img_file_info = gr.Textbox(label="Generate Info", interactive=False)
+ img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6)
+ gr.HTML("
")
img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
- img_file_time= gr.Textbox(value="", label="Create Time", interactive=False)
-
+ img_file_time= gr.HTML()
+ with gr.Row():
+ if tabname != faverate_tab_name:
+ save_btn = gr.Button('Collect')
+ pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
+ pnginfo_send_to_img2img = gr.Button('Send to img2img')
+
- # hiden items
- with gr.Row(visible=False):
- visible_img_num = gr.Number()
- date_to_recorder = gr.State([])
- last_date_from = gr.Textbox()
- tabname_box = gr.Textbox(tabname)
- image_index = gr.Textbox(value=-1)
- set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
- filenames = gr.State()
- all_images_list = gr.State()
- hidden = gr.Image(type="pil")
- info1 = gr.Textbox()
- info2 = gr.Textbox()
-
- img_path.submit(change_dir, inputs=[img_path, date_to], outputs=[warning, main_panel, warning_box, date_to])
- #change date
- change_date_output = [date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from]
+ # hiden items
+ with gr.Row(visible=False):
+ renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
+ batch_date_to = gr.Textbox(label="Date to")
+ visible_img_num = gr.Number()
+ date_to_recorder = gr.State([])
+ last_date_from = gr.Textbox()
+ tabname_box = gr.Textbox(tabname)
+ image_index = gr.Textbox(value=-1)
+ set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
+ filenames = gr.State()
+ all_images_list = gr.State()
+ hidden = gr.Image(type="pil")
+ info1 = gr.Textbox()
+ info2 = gr.Textbox()
+
+ img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info])
+
+ #change batch
+ change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel]
- date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output)
- date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
- date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
+ batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output)
+ batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
+ batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
- newest.click(newest_click, inputs=[date_to], outputs=[date_to, date_to_recorder])
- forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder])
- backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder])
+ load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder])
+ forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
+ backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
#delete
delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num])
delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None)
- if tabname != "saved":
+ if tabname != faverate_tab_name:
save_btn.click(save_image, inputs=[img_file_name], outputs=None)
#turn page
@@ -394,18 +405,20 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
-def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict):
+def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict):
global opts;
opts = sys_opts
loads_files_num = int(opts.images_history_num_per_page)
num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num)
+ if cmp_ops.browse_all_images:
+ tabs_list.append(custom_tab_name)
with gr.Blocks(analytics_enabled=False) as images_history:
with gr.Tabs() as tabs:
- for tab in [browser_tabname, "txt2img", "img2img", "extras", "saved"]:
+ for tab in tabs_list:
with gr.Tab(tab):
with gr.Blocks(analytics_enabled=False) :
- show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
- #gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False)
- gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_finish_render", visible=False)
-
+ show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
+ gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False)
+ gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False)
+
return images_history
diff --git a/modules/shared.py b/modules/shared.py
index 1811018d..4d735414 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -74,6 +74,10 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
+parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False)
+
+
+cmd_opts = parser.parse_args()
cmd_opts = parser.parse_args()
@@ -311,6 +315,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
options_templates.update(options_section(('images-history', "Images Browser"), {
#"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"),
+ "images_history_preload": OptionInfo(False, "Preload images at startup"),
"images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"),
"images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "),
"images_history_grid_num": OptionInfo(6, "Number of grids in each row"),
diff --git a/modules/ui.py b/modules/ui.py
index 85abac4d..88f46659 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1150,7 +1150,7 @@ def create_ui(wrap_gradio_gpu_call):
"i2i":img2img_paste_fields
}
- images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
+ images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
--
cgit v1.2.3
From 0f0d6ab8e06898ce066251fc769fe14e77e98ced Mon Sep 17 00:00:00 2001
From: arcticfaded
Date: Wed, 19 Oct 2022 05:19:01 +0000
Subject: call sampler by name
---
modules/api/api.py | 11 ++++++-----
modules/api/processing.py | 6 +++---
2 files changed, 9 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index ff9df0d1..5b0c934e 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,6 +1,7 @@
from modules.api.processing import StableDiffusionProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
-from modules.sd_samplers import samplers_k_diffusion
+from modules.sd_samplers import all_samplers
+from modules.extras import run_pnginfo
import modules.shared as shared
import uvicorn
from fastapi import Body, APIRouter, HTTPException
@@ -10,7 +11,7 @@ import json
import io
import base64
-sampler_to_index = lambda name: next(filter(lambda row: name in row[1][2], enumerate(samplers_k_diffusion)), None)
+sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
@@ -53,13 +54,13 @@ class Api:
- def img2imgendoint(self):
+ def img2imgapi(self):
raise NotImplementedError
- def extrasendoint(self):
+ def extrasapi(self):
raise NotImplementedError
- def pnginfoendoint(self):
+ def pnginfoapi(self):
raise NotImplementedError
def launch(self, server_name, port):
diff --git a/modules/api/processing.py b/modules/api/processing.py
index 2e6483ee..4c541241 100644
--- a/modules/api/processing.py
+++ b/modules/api/processing.py
@@ -1,7 +1,7 @@
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.processing import StableDiffusionProcessingTxt2Img
import inspect
@@ -95,5 +95,5 @@ class PydanticModelGenerator:
StableDiffusionProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "k_euler_a"}]
-).generate_model()
+ [{"key": "sampler_index", "type": str, "default": "Euler"}]
+).generate_model()
\ No newline at end of file
--
cgit v1.2.3
From 10aca1ca3e81e69e08f556a500c3dc603451429b Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 19 Oct 2022 08:42:22 +0300
Subject: more careful loading of model weights (eliminates some issues with
checkpoints that have weird cond_stage_model layer names)
---
modules/sd_models.py | 28 +++++++++++++++++++++++++---
1 file changed, 25 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3aa21ec1..7ad6d474 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -122,11 +122,33 @@ def select_checkpoint():
return checkpoint_info
+chckpoint_dict_replacements = {
+ 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
+ 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
+ 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
+}
+
+
+def transform_checkpoint_dict_key(k):
+ for text, replacement in chckpoint_dict_replacements.items():
+ if k.startswith(text):
+ k = replacement + k[len(text):]
+
+ return k
+
+
def get_state_dict_from_checkpoint(pl_sd):
if "state_dict" in pl_sd:
- return pl_sd["state_dict"]
+ pl_sd = pl_sd["state_dict"]
+
+ sd = {}
+ for k, v in pl_sd.items():
+ new_key = transform_checkpoint_dict_key(k)
+
+ if new_key is not None:
+ sd[new_key] = v
- return pl_sd
+ return sd
def load_model_weights(model, checkpoint_info):
@@ -141,7 +163,7 @@ def load_model_weights(model, checkpoint_info):
print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd)
- model.load_state_dict(sd, strict=False)
+ missing, extra = model.load_state_dict(sd, strict=False)
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
--
cgit v1.2.3
From da72becb13e4b750fbcb3d158c3f843311ef9938 Mon Sep 17 00:00:00 2001
From: Silent <16026653+s-ilent@users.noreply.github.com>
Date: Wed, 19 Oct 2022 16:14:33 +1030
Subject: Use training width/height when training hypernetworks.
---
modules/hypernetworks/hypernetwork.py | 4 ++--
modules/ui.py | 2 ++
2 files changed, 4 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 4905710e..b8695fc1 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -196,7 +196,7 @@ def stack_conds(conds):
return torch.stack(conds)
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -225,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
diff --git a/modules/ui.py b/modules/ui.py
index fb6eb5a0..ca46343f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1341,6 +1341,8 @@ def create_ui(wrap_gradio_gpu_call):
batch_size,
dataset_directory,
log_directory,
+ training_width,
+ training_height,
steps,
create_image_every,
save_embedding_every,
--
cgit v1.2.3
From 2fd7935ef4ed296db5dfd8c7fea99244816f8cf0 Mon Sep 17 00:00:00 2001
From: Cheka
Date: Tue, 18 Oct 2022 20:28:28 -0300
Subject: Remove wrong self reference in CUDA support for invokeai
---
modules/sd_hijack_optimizations.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index a3345bb9..98123fbf 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v):
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
# Divide factor of safety as there's copying and fragmentation
- return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+ return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def einsum_op(q, k, v):
if q.device.type == 'cuda':
--
cgit v1.2.3
From bcfbb33e50a48b237d8d961cc2be038db53774d5 Mon Sep 17 00:00:00 2001
From: Anastasius
Date: Mon, 17 Oct 2022 13:35:20 -0700
Subject: Added time left estimation
---
modules/ui.py | 14 +++++++++++++-
1 file changed, 13 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index ca46343f..9a54aa16 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -261,6 +261,15 @@ def wrap_gradio_call(func, extra_outputs=None):
return f
+def calc_time_left(progress):
+ if progress == 0:
+ return "N/A"
+ else:
+ time_since_start = time.time() - shared.state.time_start
+ eta = (time_since_start/progress)
+ return time.strftime('%H:%M:%S', time.gmtime(eta-time_since_start))
+
+
def check_progress_call(id_part):
if shared.state.job_count == 0:
return "", gr_show(False), gr_show(False), gr_show(False)
@@ -272,11 +281,13 @@ def check_progress_call(id_part):
if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+ time_left = calc_time_left( progress )
+
progress = min(progress, 1)
progressbar = ""
if opts.show_progressbar:
- progressbar = f"""{str(int(progress*100))+"%" if progress > 0.01 else ""}
"""
+ progressbar = f"""{str(int(progress*100))+"% ETA:"+time_left if progress > 0.01 else ""}
"""
image = gr_show(False)
preview_visibility = gr_show(False)
@@ -308,6 +319,7 @@ def check_progress_call_initial(id_part):
shared.state.current_latent = None
shared.state.current_image = None
shared.state.textinfo = None
+ shared.state.time_start = time.time()
return check_progress_call(id_part)
--
cgit v1.2.3
From 442dbedc159bb7e9cf94f0c3626f8a409e0a50eb Mon Sep 17 00:00:00 2001
From: Anastasius
Date: Tue, 18 Oct 2022 10:38:07 -0700
Subject: Estimated time displayed if jobs take more 60 sec
---
modules/ui.py | 17 ++++++++++++-----
1 file changed, 12 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 9a54aa16..fa54110b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -261,13 +261,17 @@ def wrap_gradio_call(func, extra_outputs=None):
return f
-def calc_time_left(progress):
+def calc_time_left(progress, threshold, label, force_display):
if progress == 0:
- return "N/A"
+ return ""
else:
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
- return time.strftime('%H:%M:%S', time.gmtime(eta-time_since_start))
+ eta_relative = eta-time_since_start
+ if eta_relative > threshold or force_display:
+ return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+ else:
+ return ""
def check_progress_call(id_part):
@@ -281,13 +285,15 @@ def check_progress_call(id_part):
if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
- time_left = calc_time_left( progress )
+ time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display )
+ if time_left != "":
+ shared.state.time_left_force_display = True
progress = min(progress, 1)
progressbar = ""
if opts.show_progressbar:
- progressbar = f"""{str(int(progress*100))+"% ETA:"+time_left if progress > 0.01 else ""}
"""
+ progressbar = f"""{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}
"""
image = gr_show(False)
preview_visibility = gr_show(False)
@@ -320,6 +326,7 @@ def check_progress_call_initial(id_part):
shared.state.current_image = None
shared.state.textinfo = None
shared.state.time_start = time.time()
+ shared.state.time_left_force_display = False
return check_progress_call(id_part)
--
cgit v1.2.3
From 1d4aa376e6111e90888a30ae24d2bcd7f978ec51 Mon Sep 17 00:00:00 2001
From: Anastasius
Date: Tue, 18 Oct 2022 12:42:39 -0700
Subject: Predictable long operation check for time estimation
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index fa54110b..38ba1138 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -268,7 +268,7 @@ def calc_time_left(progress, threshold, label, force_display):
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
- if eta_relative > threshold or force_display:
+ if (eta_relative > threshold and progress > 0.02) or force_display:
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
else:
return ""
--
cgit v1.2.3
From bb0e7232b301d1706bbd0e09367dece3bb7ac07c Mon Sep 17 00:00:00 2001
From: Ikko Ashimine
Date: Wed, 19 Oct 2022 02:18:56 +0900
Subject: Fix typo in prompt_parser.py
assoicated -> associated
---
modules/prompt_parser.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 919d5d31..f70872c4 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -275,7 +275,7 @@ re_attention = re.compile(r"""
def parse_prompt_attention(text):
"""
- Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
--
cgit v1.2.3
From f894dd552f68bea27476f1f360ab8e79f3a65b4f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 19 Oct 2022 12:45:30 +0300
Subject: fix for broken checkpoint merger
---
modules/sd_models.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7ad6d474..eae22e87 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -148,7 +148,10 @@ def get_state_dict_from_checkpoint(pl_sd):
if new_key is not None:
sd[new_key] = v
- return sd
+ pl_sd.clear()
+ pl_sd.update(sd)
+
+ return pl_sd
def load_model_weights(model, checkpoint_info):
--
cgit v1.2.3
From 42fbda83bb9830af18187fddb50c1bedd01da502 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Wed, 19 Oct 2022 14:30:33 +0000
Subject: layer options moves into create hnet ui
---
modules/hypernetworks/hypernetwork.py | 64 +++++++++++++++++------------------
modules/hypernetworks/ui.py | 9 +++--
modules/shared.py | 2 --
modules/ui.py | 8 +++--
4 files changed, 45 insertions(+), 38 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 583ada31..7d519cd9 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -19,37 +19,21 @@ from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
-def parse_layer_structure(dim, state_dict):
- i = 0
- res = [1]
- while (key := "linear.{}.weight".format(i)) in state_dict:
- weight = state_dict[key]
- res.append(len(weight) // dim)
- i += 1
- return res
-
-
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
- layer_structure = None
- add_layer_norm = False
- def __init__(self, dim, state_dict=None):
+ def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
super().__init__()
- if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None:
- layer_structure = (1, 2, 1)
+ if layer_structure is not None:
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
else:
- if self.layer_structure is not None:
- assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
- assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
- layer_structure = self.layer_structure
- else:
- layer_structure = parse_layer_structure(dim, state_dict)
+ layer_structure = parse_layer_structure(dim, state_dict)
linears = []
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
- if self.add_layer_norm:
+ if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
self.linear = torch.nn.Sequential(*linears)
@@ -77,38 +61,47 @@ class HypernetworkModule(torch.nn.Module):
return x + self.linear(x) * self.multiplier
def trainables(self):
- res = []
+ layer_structure = []
for layer in self.linear:
- res += [layer.weight, layer.bias]
- return res
+ layer_structure += [layer.weight, layer.bias]
+ return layer_structure
def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-def apply_layer_structure(value=None):
- HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure
+def parse_layer_structure(dim, state_dict):
+ i = 0
+ layer_structure = [1]
+ while (key := "linear.{}.weight".format(i)) in state_dict:
+ weight = state_dict[key]
+ layer_structure.append(len(weight) // dim)
+ i += 1
-def apply_layer_norm(value=None):
- HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm
+ return layer_structure
class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
self.filename = None
self.name = name
self.layers = {}
self.step = 0
self.sd_checkpoint = None
self.sd_checkpoint_name = None
+ self.layer_structure = layer_structure
+ self.add_layer_norm = add_layer_norm
for size in enable_sizes or []:
- self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
+ self.layers[size] = (
+ HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
+ )
def weights(self):
res = []
@@ -128,6 +121,8 @@ class Hypernetwork:
state_dict['step'] = self.step
state_dict['name'] = self.name
+ state_dict['layer_structure'] = self.layer_structure
+ state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -142,10 +137,15 @@ class Hypernetwork:
for size, sd in state_dict.items():
if type(size) == int:
- self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
+ self.layers[size] = (
+ HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
+ HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
+ )
self.name = state_dict.get('name', self.name)
self.step = state_dict.get('step', 0)
+ self.layer_structure = state_dict.get('layer_structure', None)
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index dfa599af..7e8ea95e 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -9,11 +9,16 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name, enable_sizes):
+def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
+ name=name,
+ enable_sizes=[int(x) for x in enable_sizes],
+ layer_structure=layer_structure,
+ add_layer_norm=add_layer_norm,
+ )
hypernet.save(fn)
shared.reload_hypernetworks()
diff --git a/modules/shared.py b/modules/shared.py
index 0540cae9..faede821 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -260,8 +260,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
- "sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}),
- "sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
diff --git a/modules/ui.py b/modules/ui.py
index ca46343f..d9ee462f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -458,14 +458,14 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
)
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
)
@@ -1198,6 +1198,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
+ new_hypernetwork_layer_structure = gr.Dropdown(label="Hypernetwork layer structure", choices=[(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)])
+ new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
with gr.Row():
with gr.Column(scale=3):
@@ -1280,6 +1282,8 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
new_hypernetwork_name,
new_hypernetwork_sizes,
+ new_hypernetwork_layer_structure,
+ new_hypernetwork_add_layer_norm,
],
outputs=[
train_hypernetwork_name,
--
cgit v1.2.3
From 3770b8d2fa62066d472a04739c7b84bce8538832 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Wed, 19 Oct 2022 15:28:42 +0000
Subject: enable to write layer structure of hn himself
---
modules/hypernetworks/ui.py | 4 ++++
modules/ui.py | 2 +-
2 files changed, 5 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 7e8ea95e..08f75f15 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -1,5 +1,6 @@
import html
import os
+import re
import gradio as gr
@@ -13,6 +14,9 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
+ if type(layer_structure) == str:
+ layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure)))
+
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
name=name,
enable_sizes=[int(x) for x in enable_sizes],
diff --git a/modules/ui.py b/modules/ui.py
index d9ee462f..18a2add0 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1198,7 +1198,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
- new_hypernetwork_layer_structure = gr.Dropdown(label="Hypernetwork layer structure", choices=[(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)])
+ new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
with gr.Row():
--
cgit v1.2.3
From 019a3a88f07766f2d32c32fbe8e41625f28ecb5e Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 19 Oct 2022 17:15:47 +0100
Subject: Update ui.py
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index d2e24880..1573ef82 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1247,7 +1247,7 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Tab(label="Train"):
- gr.HTML(value="Train an embedding; must specify a directory with a set of 1:1 ratio images
")
+ gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images
Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork wiki
")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
--
cgit v1.2.3
From c6e9fed5003631c87d548e74d6e359678959a453 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 19 Oct 2022 19:21:16 +0300
Subject: fix for #3086 failing to load any previous hypernet
---
modules/hypernetworks/hypernetwork.py | 60 ++++++++++++++++-------------------
1 file changed, 28 insertions(+), 32 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7d519cd9..74300122 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -24,11 +24,10 @@ class HypernetworkModule(torch.nn.Module):
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
super().__init__()
- if layer_structure is not None:
- assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
- assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
- else:
- layer_structure = parse_layer_structure(dim, state_dict)
+
+ assert layer_structure is not None, "layer_structure mut not be None"
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
@@ -39,23 +38,30 @@ class HypernetworkModule(torch.nn.Module):
self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
- try:
- self.load_state_dict(state_dict)
- except RuntimeError:
- self.try_load_previous(state_dict)
+ self.fix_old_state_dict(state_dict)
+ self.load_state_dict(state_dict)
else:
for layer in self.linear:
- layer.weight.data.normal_(mean = 0.0, std = 0.01)
+ layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
self.to(devices.device)
- def try_load_previous(self, state_dict):
- states = self.state_dict()
- states['linear.0.bias'].copy_(state_dict['linear1.bias'])
- states['linear.0.weight'].copy_(state_dict['linear1.weight'])
- states['linear.1.bias'].copy_(state_dict['linear2.bias'])
- states['linear.1.weight'].copy_(state_dict['linear2.weight'])
+ def fix_old_state_dict(self, state_dict):
+ changes = {
+ 'linear1.bias': 'linear.0.bias',
+ 'linear1.weight': 'linear.0.weight',
+ 'linear2.bias': 'linear.1.bias',
+ 'linear2.weight': 'linear.1.weight',
+ }
+
+ for fr, to in changes.items():
+ x = state_dict.get(fr, None)
+ if x is None:
+ continue
+
+ del state_dict[fr]
+ state_dict[to] = x
def forward(self, x):
return x + self.linear(x) * self.multiplier
@@ -71,18 +77,6 @@ def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-def parse_layer_structure(dim, state_dict):
- i = 0
- layer_structure = [1]
-
- while (key := "linear.{}.weight".format(i)) in state_dict:
- weight = state_dict[key]
- layer_structure.append(len(weight) // dim)
- i += 1
-
- return layer_structure
-
-
class Hypernetwork:
filename = None
name = None
@@ -135,17 +129,18 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
+ self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
+
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
- HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
)
self.name = state_dict.get('name', self.name)
self.step = state_dict.get('step', 0)
- self.layer_structure = state_dict.get('layer_structure', None)
- self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
@@ -244,6 +239,7 @@ def stack_conds(conds):
return torch.stack(conds)
+
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected'
--
cgit v1.2.3
From 2ce52d32e41fb523d1494f45073fd18496e52d35 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Wed, 19 Oct 2022 16:31:12 +0000
Subject: fix for #3086 failing to load any previous hypernet
---
modules/hypernetworks/hypernetwork.py | 60 ++++++++++++++++-------------------
1 file changed, 28 insertions(+), 32 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7d519cd9..74300122 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -24,11 +24,10 @@ class HypernetworkModule(torch.nn.Module):
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
super().__init__()
- if layer_structure is not None:
- assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
- assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
- else:
- layer_structure = parse_layer_structure(dim, state_dict)
+
+ assert layer_structure is not None, "layer_structure mut not be None"
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
@@ -39,23 +38,30 @@ class HypernetworkModule(torch.nn.Module):
self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
- try:
- self.load_state_dict(state_dict)
- except RuntimeError:
- self.try_load_previous(state_dict)
+ self.fix_old_state_dict(state_dict)
+ self.load_state_dict(state_dict)
else:
for layer in self.linear:
- layer.weight.data.normal_(mean = 0.0, std = 0.01)
+ layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
self.to(devices.device)
- def try_load_previous(self, state_dict):
- states = self.state_dict()
- states['linear.0.bias'].copy_(state_dict['linear1.bias'])
- states['linear.0.weight'].copy_(state_dict['linear1.weight'])
- states['linear.1.bias'].copy_(state_dict['linear2.bias'])
- states['linear.1.weight'].copy_(state_dict['linear2.weight'])
+ def fix_old_state_dict(self, state_dict):
+ changes = {
+ 'linear1.bias': 'linear.0.bias',
+ 'linear1.weight': 'linear.0.weight',
+ 'linear2.bias': 'linear.1.bias',
+ 'linear2.weight': 'linear.1.weight',
+ }
+
+ for fr, to in changes.items():
+ x = state_dict.get(fr, None)
+ if x is None:
+ continue
+
+ del state_dict[fr]
+ state_dict[to] = x
def forward(self, x):
return x + self.linear(x) * self.multiplier
@@ -71,18 +77,6 @@ def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-def parse_layer_structure(dim, state_dict):
- i = 0
- layer_structure = [1]
-
- while (key := "linear.{}.weight".format(i)) in state_dict:
- weight = state_dict[key]
- layer_structure.append(len(weight) // dim)
- i += 1
-
- return layer_structure
-
-
class Hypernetwork:
filename = None
name = None
@@ -135,17 +129,18 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
+ self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
+
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
- HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
)
self.name = state_dict.get('name', self.name)
self.step = state_dict.get('step', 0)
- self.layer_structure = state_dict.get('layer_structure', None)
- self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
@@ -244,6 +239,7 @@ def stack_conds(conds):
return torch.stack(conds)
+
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected'
--
cgit v1.2.3
From 57eb1a64c85d995cacb4fa3832e87405bf6820b9 Mon Sep 17 00:00:00 2001
From: Alexandre Simard
Date: Wed, 19 Oct 2022 12:28:27 -0400
Subject: Update ui.py
---
modules/ui.py | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index d2e24880..c9a923ab 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -268,8 +268,13 @@ def calc_time_left(progress, threshold, label, force_display):
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
- if (eta_relative > threshold and progress > 0.02) or force_display:
- return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+ if (eta_relative > threshold and progress > 0.02) or force_display:
+ if eta_relative > 3600:
+ return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+ elif eta_relative > 60:
+ return label + time.strftime('%M:%S', time.gmtime(eta_relative))
+ else:
+ return label + time.strftime('%Ss', time.gmtime(eta_relative))
else:
return ""
@@ -285,7 +290,7 @@ def check_progress_call(id_part):
if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
- time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display )
+ time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display )
if time_left != "":
shared.state.time_left_force_display = True
@@ -293,7 +298,7 @@ def check_progress_call(id_part):
progressbar = ""
if opts.show_progressbar:
- progressbar = f"""{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}
"""
+ progressbar = f"""{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}
"""
image = gr_show(False)
preview_visibility = gr_show(False)
--
cgit v1.2.3
From 1e4809b251d478a102fd980dcfc26e21d6d3730b Mon Sep 17 00:00:00 2001
From: Alexandre Simard
Date: Wed, 19 Oct 2022 12:53:23 -0400
Subject: Added a bit of padding to the left
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index c9a923ab..a2dbd41e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -298,7 +298,7 @@ def check_progress_call(id_part):
progressbar = ""
if opts.show_progressbar:
- progressbar = f"""{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}
"""
+ progressbar = f"""{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}
"""
image = gr_show(False)
preview_visibility = gr_show(False)
--
cgit v1.2.3
From 5e012e4dfa5dcfeade0394678cf14b70682dba6c Mon Sep 17 00:00:00 2001
From: timntorres
Date: Wed, 19 Oct 2022 06:17:47 -0700
Subject: Infotext saves more specific hypernet name.
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index ea926fc3..bcb0c32c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
--
cgit v1.2.3
From 46122c4ff6aadc0f96e657f88dbac7bbd9f9bf99 Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Wed, 19 Oct 2022 19:18:52 +0300
Subject: Send empty prompts as valid generation parameter
---
modules/generation_parameters_copypaste.py | 3 ---
1 file changed, 3 deletions(-)
(limited to 'modules')
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index c27826b6..98d24406 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -45,10 +45,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else:
prompt += ("" if prompt == "" else "\n") + line
- if len(prompt) > 0:
res["Prompt"] = prompt
-
- if len(negative_prompt) > 0:
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
--
cgit v1.2.3
From b748b583c0b9f771c1be509175a6913e3f2ad97c Mon Sep 17 00:00:00 2001
From: Mackerel
Date: Wed, 19 Oct 2022 14:22:03 -0400
Subject: generation_parameters_copypaste.py: fix indent
---
modules/generation_parameters_copypaste.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 98d24406..0f041449 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -45,8 +45,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else:
prompt += ("" if prompt == "" else "\n") + line
- res["Prompt"] = prompt
- res["Negative prompt"] = negative_prompt
+ res["Prompt"] = prompt
+ res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
m = re_imagesize.match(v)
--
cgit v1.2.3
From eb7ba4b713ac2fb960ecf6365b1de0c89451e583 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 19 Oct 2022 19:50:46 +0100
Subject: update training header text
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 1573ef82..93c0767c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1247,7 +1247,7 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Tab(label="Train"):
- gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images
Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork wiki
")
+ gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images
Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork [wiki]
")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
--
cgit v1.2.3
From 4d663055ded968831ec97f047dfa8e94036cf1c1 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Wed, 19 Oct 2022 20:33:18 +0100
Subject: update ui with extra training options
---
modules/ui.py | 11 +++++++++--
1 file changed, 9 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 93c0767c..cdb9d335 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1206,6 +1206,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
with gr.Row():
with gr.Column(scale=3):
@@ -1219,6 +1220,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
with gr.Row():
with gr.Column(scale=3):
@@ -1247,14 +1249,17 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Tab(label="Train"):
- gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images
Initial learning rates: 0.005 for an Embedding, 0.00001 for Hypernetwork [wiki]
")
+ gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
- learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+ with gr.Row():
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
+
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
@@ -1288,6 +1293,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name,
initialization_text,
nvpt,
+ overwrite_old_embedding,
],
outputs=[
train_embedding_name,
@@ -1303,6 +1309,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes,
new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm,
+ overwrite_old_hypernetwork,
],
outputs=[
train_hypernetwork_name,
--
cgit v1.2.3
From 8e7097d06a6a261580d34375c9d2a9e4ffc63ffa Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Wed, 19 Oct 2022 13:47:45 -0700
Subject: Added support for RunwayML inpainting model
---
modules/processing.py | 34 ++++++-
modules/sd_hijack_inpainting.py | 208 ++++++++++++++++++++++++++++++++++++++++
modules/sd_models.py | 16 +++-
modules/sd_samplers.py | 50 +++++++---
4 files changed, 293 insertions(+), 15 deletions(-)
create mode 100644 modules/sd_hijack_inpainting.py
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index bcb0c32c..a6c308f9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -546,7 +546,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device)
+ image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=image_conditioning)
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@@ -714,10 +723,31 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
+ if self.image_mask is not None:
+ conditioning_mask = np.array(self.image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ conditioning_mask = conditioning_mask.to(image.device)
+ conditioning_image = image * (1.0 - conditioning_mask)
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
+
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
new file mode 100644
index 00000000..7e5670d6
--- /dev/null
+++ b/modules/sd_hijack_inpainting.py
@@ -0,0 +1,208 @@
+import torch
+import numpy as np
+
+from tqdm import tqdm
+from einops import rearrange, repeat
+from omegaconf import ListConfig
+
+from types import MethodType
+
+import ldm.models.diffusion.ddpm
+import ldm.models.diffusion.ddim
+
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+
+# =================================================================================================
+# Monkey patch DDIMSampler methods from RunwayML repo directly.
+# Adapted from:
+# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
+# =================================================================================================
+@torch.no_grad()
+def sample(
+ self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list):
+ ctmp = elf.inpainting_fill == 2:
+ self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
+ elif self.inpainting_fill == 3:
+ self.init_latent = self.init_latent * self.mask
+
+ if self.image_mask is not None:
+ conditioning_mask = np.array(self.image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ conditioning_mask = conditioning_mask.to(image.device)
+ conditioning_image = image * (1.0 - conditioning_mask)
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
+
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ x = create_random_tensors([opctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+
+@torch.no_grad()
+def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
+ for i in range(len(c[k]))
+ ]
+ else:
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+
+# =================================================================================================
+# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
+# Adapted from:
+# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
+# =================================================================================================
+
+@torch.no_grad()
+def get_unconditional_conditioning(self, batch_size, null_label=None):
+ if null_label is not None:
+ xc = null_label
+ if isinstance(xc, ListConfig):
+ xc = list(xc)
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ if hasattr(xc, "to"):
+ xc = xc.to(self.device)
+ c = self.get_learned_conditioning(xc)
+ else:
+ # todo: get null label from cond_stage_model
+ raise NotImplementedError()
+ c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+ return c
+
+class LatentInpaintDiffusion(LatentDiffusion):
+ def __init__(
+ self,
+ concat_keys=("mask", "masked_image"),
+ masked_image_key="masked_image",
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.masked_image_key = masked_image_key
+ assert self.masked_image_key in concat_keys
+ self.concat_keys = concat_keys
+
+def should_hijack_inpainting(checkpoint_info):
+ return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
+
+def do_inpainting_hijack():
+ ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
+ ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
+ ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
+ ldm.models.diffusion.ddim.DDIMSampler.sample = sample
\ No newline at end of file
diff --git a/modules/sd_models.py b/modules/sd_models.py
index eae22e87..47836d25 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices
from modules.paths import models_path
+from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -211,6 +212,19 @@ def load_model():
print(f"Loading config from: {checkpoint_info.config}")
sd_config = OmegaConf.load(checkpoint_info.config)
+
+ if should_hijack_inpainting(checkpoint_info):
+ do_inpainting_hijack()
+
+ # Hardcoded config for now...
+ sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+ sd_config.model.params.use_ema = False
+ sd_config.model.params.conditioning_key = "hybrid"
+ sd_config.model.params.unet_config.params.in_channels = 9
+
+ # Create a "fake" config with a different name so that we know to unload it when switching models.
+ checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
@@ -234,7 +248,7 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
- if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
+ if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
shared.sd_model = load_model()
return shared.sd_model
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index b58e810b..9d3cf289 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -136,9 +136,15 @@ class VanillaStableDiffusionSampler:
if self.stop_at is not None and self.step > self.stop_at:
raise InterruptedException
+ # Have to unwrap the inpainting conditioning here to perform pre-preocessing
+ image_conditioning = None
+ if isinstance(cond, dict):
+ image_conditioning = cond["c_concat"][0]
+ cond = cond["c_crossattn"][0]
+ unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
@@ -157,6 +163,10 @@ class VanillaStableDiffusionSampler:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
+ if image_conditioning is not None:
+ cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None:
@@ -182,7 +192,7 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
self.initialize(p)
@@ -202,7 +212,7 @@ class VanillaStableDiffusionSampler:
return samples
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
self.initialize(p)
self.init_latent = None
@@ -210,6 +220,11 @@ class VanillaStableDiffusionSampler:
steps = steps or p.steps
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ if image_conditioning is not None:
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
# existing code fails with certain step counts, like 9
try:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
@@ -228,7 +243,7 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
- def forward(self, x, sigma, uncond, cond, cond_scale):
+ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise InterruptedException
@@ -239,28 +254,29 @@ class CFGDenoiser(torch.nn.Module):
repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+ x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
@@ -361,7 +377,7 @@ class KDiffusionSampler:
return extra_params_kwargs
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
if p.sampler_noise_scheduler_override:
@@ -389,11 +405,16 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps
if p.sampler_noise_scheduler_override:
@@ -414,7 +435,12 @@ class KDiffusionSampler:
else:
extra_params_kwargs['sigmas'] = sigmas
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
--
cgit v1.2.3
From 0719c10bf1b817364a498ee11b90d30d3d527344 Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Wed, 19 Oct 2022 13:56:26 -0700
Subject: Fixed copying mistake
---
modules/sd_hijack_inpainting.py | 79 +++++++++++++----------------------------
1 file changed, 25 insertions(+), 54 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 7e5670d6..d4d28d2e 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -19,63 +19,35 @@ from ldm.models.diffusion.ddim import DDIMSampler, noise_like
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
# =================================================================================================
@torch.no_grad()
-def sample(
- self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
+def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
- ctmp = elf.inpainting_fill == 2:
- self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
- elif self.inpainting_fill == 3:
- self.init_latent = self.init_latent * self.mask
-
- if self.image_mask is not None:
- conditioning_mask = np.array(self.image_mask.convert("L"))
- conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
- conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
-
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
- else:
- conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
-
- # Create another latent image, this time with a masked version of the original input.
- conditioning_mask = conditioning_mask.to(image.device)
- conditioning_image = image * (1.0 - conditioning_mask)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
-
- # Create the concatenated conditioning tensor to be fed to `c_concat`
- conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
- conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
- self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
- self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
- x = create_random_tensors([opctmp[0]
+ ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
@@ -106,7 +78,6 @@ def sample(
)
return samples, intermediates
-
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
--
cgit v1.2.3
From dde9f960727bfe151d418e43685a2881cf580a17 Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Wed, 19 Oct 2022 14:14:24 -0700
Subject: added support for ddim img2img
---
modules/sd_samplers.py | 6 ++++++
1 file changed, 6 insertions(+)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 9d3cf289..d270e4df 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -208,6 +208,12 @@ class VanillaStableDiffusionSampler:
self.init_latent = x
self.step = 0
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ if image_conditioning is not None:
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
--
cgit v1.2.3
From c418467c03db916c3e5312e6ac4a67365e196dbd Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Wed, 19 Oct 2022 15:09:43 -0700
Subject: Don't compute latent mask if were not using it. Also added support
for fixed highres_fix generation.
---
modules/processing.py | 72 +++++++++++++++++++++++++++++++-------------------
modules/sd_samplers.py | 4 +++
2 files changed, 49 insertions(+), 27 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index a6c308f9..684e5833 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -541,12 +541,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
- self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
-
- if not self.enable_hr:
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
-
+ def create_dummy_mask(self, x):
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
@@ -555,11 +551,23 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=image_conditioning)
+ else:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ image_conditioning = torch.zeros(x.shape[0], 5, x.shape[-2], x.shape[-1], dtype=x.dtype, device=x.device)
+
+ return image_conditioning
+
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+
+ if not self.enable_hr:
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
@@ -596,7 +604,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
return samples
@@ -723,26 +731,36 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- if self.image_mask is not None:
- conditioning_mask = np.array(self.image_mask.convert("L"))
- conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
- conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+ conditioning_key = self.sampler.conditioning_key
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
+ if conditioning_key in {'hybrid', 'concat'}:
+ if self.image_mask is not None:
+ conditioning_mask = np.array(self.image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ conditioning_mask = conditioning_mask.to(image.device)
+ conditioning_image = image * (1.0 - conditioning_mask)
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
else:
- conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
-
- # Create another latent image, this time with a masked version of the original input.
- conditioning_mask = conditioning_mask.to(image.device)
- conditioning_image = image * (1.0 - conditioning_mask)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
-
- # Create the concatenated conditioning tensor to be fed to `c_concat`
- conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
- conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
- self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
- self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
+ self.image_conditioning = torch.zeros(
+ self.init_latent.shape[0], 5, self.init_latent.shape[-2], self.init_latent.shape[-1],
+ dtype=self.init_latent.dtype,
+ device=self.init_latent.device
+ )
+
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index d270e4df..c21be26e 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler:
self.config = None
self.last_latent = None
+ self.conditioning_key = sd_model.model.conditioning_key
+
def number_of_needed_noises(self, p):
return 0
@@ -328,6 +330,8 @@ class KDiffusionSampler:
self.config = None
self.last_latent = None
+ self.conditioning_key = sd_model.model.conditioning_key
+
def callback_state(self, d):
step = d['i']
latent = d["denoised"]
--
cgit v1.2.3
From d6ea5841374a28f3f6deb73abc251c8f0bcb240f Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:07:57 +0100
Subject: change html output
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7d519cd9..73c1cb80 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -380,7 +380,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
Loss: {mean_loss:.7f}
Step: {hypernetwork.step}
Last prompt: {html.escape(entries[0].cond_text)}
-Last saved embedding: {html.escape(last_saved_file)}
+Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
--
cgit v1.2.3
From 166be3919b817cee5e702fd01c34afe9081b952c Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:09:40 +0100
Subject: allow overwrite old hn
---
modules/hypernetworks/ui.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 08f75f15..f45345ea 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -10,9 +10,10 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure)))
--
cgit v1.2.3
From 0087079c2d487b67b06ffc30f36ce486a74e6318 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:10:59 +0100
Subject: allow overwrite old embedding
---
modules/textual_inversion/textual_inversion.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 3be69562..5776778b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -153,7 +153,7 @@ class EmbeddingDatabase:
return None, None
-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
@@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
--
cgit v1.2.3
From 632e8d660293081cadb145d8062e5aff0a4a8f0d Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:19:40 +0100
Subject: split learn rates
---
modules/ui.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index cdb9d335..d07184ee 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1342,7 +1342,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
- learn_rate,
+ embedding_learn_rate,
batch_size,
dataset_directory,
log_directory,
@@ -1367,7 +1367,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
- learn_rate,
+ hypernetwork_learn_rate,
batch_size,
dataset_directory,
log_directory,
--
cgit v1.2.3
From c3835ec85cbb44fa3c46fa871c622b6fee235c89 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:24:24 +0100
Subject: pass overwrite old flag
---
modules/textual_inversion/ui.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index 36881e7a..e712284d 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
-def create_embedding(name, initialization_text, nvpt):
- filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
--
cgit v1.2.3
From 4d6b9f76a55fd0ac0f72634071032dd9c6efb409 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:27:16 +0100
Subject: reorder create_hypernetwork params
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index d07184ee..322c082b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1307,9 +1307,9 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
new_hypernetwork_name,
new_hypernetwork_sizes,
+ overwrite_old_hypernetwork,
new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm,
- overwrite_old_hypernetwork,
],
outputs=[
train_hypernetwork_name,
--
cgit v1.2.3
From fbcce66601994f6ed370db36d9c238840fed6bd2 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:46:54 +0100
Subject: add existing caption file handling
---
modules/textual_inversion/preprocess.py | 32 ++++++++++++++++++++++++--------
1 file changed, 24 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 886cf0c3..5c43fe13 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -48,7 +48,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
- def save_pic_with_caption(image, index):
+ def save_pic_with_caption(image, index, existing_caption=None):
caption = ""
if process_caption:
@@ -66,17 +66,26 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
basename = f"{index:05}-{subindex[0]}-{filename_part}"
image.save(os.path.join(dst, f"{basename}.png"))
+ if preprocess_txt_action == 'prepend' and existing_caption:
+ caption = existing_caption + ' ' + caption
+ elif preprocess_txt_action == 'append' and existing_caption:
+ caption = caption + ' ' + existing_caption
+ elif preprocess_txt_action == 'copy' and existing_caption:
+ caption = existing_caption
+
+ caption = caption.strip()
+
if len(caption) > 0:
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
subindex[0] += 1
- def save_pic(image, index):
+ def save_pic(image, index, existing_caption=None):
save_pic_with_caption(image, index)
if process_flip:
- save_pic_with_caption(ImageOps.mirror(image), index)
+ save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
@@ -86,6 +95,13 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
except Exception:
continue
+ existing_caption = None
+
+ try:
+ existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read()
+ except Exception as e:
+ print(e)
+
if shared.state.interrupted:
break
@@ -97,20 +113,20 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
img = img.resize((width, height * img.height // img.width))
top = img.crop((0, 0, width, height))
- save_pic(top, index)
+ save_pic(top, index, existing_caption=existing_caption)
bot = img.crop((0, img.height - height, width, img.height))
- save_pic(bot, index)
+ save_pic(bot, index, existing_caption=existing_caption)
elif process_split and is_wide:
img = img.resize((width * img.width // img.height, height))
left = img.crop((0, 0, width, height))
- save_pic(left, index)
+ save_pic(left, index, existing_caption=existing_caption)
right = img.crop((img.width - width, 0, img.width, height))
- save_pic(right, index)
+ save_pic(right, index, existing_caption=existing_caption)
else:
img = images.resize_image(1, img, width, height)
- save_pic(img, index)
+ save_pic(img, index, existing_caption=existing_caption)
shared.state.nextjob()
--
cgit v1.2.3
From ab353b141df8eee042b0964bcb645015dabf3459 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:48:07 +0100
Subject: link existing txt option
---
modules/ui.py | 2 ++
1 file changed, 2 insertions(+)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 322c082b..7f52ac0c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1234,6 +1234,7 @@ def create_ui(wrap_gradio_gpu_call):
process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', choices=['ignore', 'copy', 'prepend', 'append'])
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
@@ -1326,6 +1327,7 @@ def create_ui(wrap_gradio_gpu_call):
process_dst,
process_width,
process_height,
+ preprocess_txt_action,
process_flip,
process_split,
process_caption,
--
cgit v1.2.3
From 9b65c4ecf4f8eb6187ee721918adebe68e9bc631 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:49:23 +0100
Subject: pass preprocess_txt_action param
---
modules/textual_inversion/preprocess.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 5c43fe13..3713bc89 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -11,7 +11,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
try:
if process_caption:
shared.interrogator.load()
@@ -21,7 +21,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
- preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru)
finally:
@@ -33,7 +33,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
--
cgit v1.2.3
From 55d8c6cce6d3aef848b9f194adad2ce53064d8b7 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 00:53:29 +0100
Subject: default to ignore existing captions
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 7f52ac0c..bd5f1b05 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1234,7 +1234,7 @@ def create_ui(wrap_gradio_gpu_call):
process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
- preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', choices=['ignore', 'copy', 'prepend', 'append'])
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
--
cgit v1.2.3
From 6f98e89486f55b0e4657e96ce640cf1c4675d187 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Thu, 20 Oct 2022 00:10:45 +0000
Subject: update
---
modules/hypernetworks/hypernetwork.py | 29 +++++++++++++++--------
modules/hypernetworks/ui.py | 3 ++-
modules/ui.py | 43 +++++++++++++++++++----------------
3 files changed, 44 insertions(+), 31 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 74300122..7d617680 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -22,16 +22,20 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
- def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
+ def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
super().__init__()
- assert layer_structure is not None, "layer_structure mut not be None"
+ assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+ if activation_func == "relu":
+ linears.append(torch.nn.ReLU())
+ if activation_func == "leakyrelu":
+ linears.append(torch.nn.LeakyReLU())
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -42,8 +46,9 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
- layer.weight.data.normal_(mean=0.0, std=0.01)
- layer.bias.data.zero_()
+ if not "ReLU" in layer.__str__():
+ layer.weight.data.normal_(mean=0.0, std=0.01)
+ layer.bias.data.zero_()
self.to(devices.device)
@@ -69,7 +74,8 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
- layer_structure += [layer.weight, layer.bias]
+ if not "ReLU" in layer.__str__():
+ layer_structure += [layer.weight, layer.bias]
return layer_structure
@@ -81,7 +87,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
self.filename = None
self.name = name
self.layers = {}
@@ -90,11 +96,12 @@ class Hypernetwork:
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
self.add_layer_norm = add_layer_norm
+ self.activation_func = activation_func
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
+ HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
)
def weights(self):
@@ -117,6 +124,7 @@ class Hypernetwork:
state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure
state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['activation_func'] = self.activation_func
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -131,12 +139,13 @@ class Hypernetwork:
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ self.activation_func = state_dict.get('activation_func', None)
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
- HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
)
self.name = state_dict.get('name', self.name)
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 08f75f15..83f9547b 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -10,7 +10,7 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
+def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
@@ -22,6 +22,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
add_layer_norm=add_layer_norm,
+ activation_func=activation_func,
)
hypernet.save(fn)
diff --git a/modules/ui.py b/modules/ui.py
index d2e24880..8751fa9c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -5,43 +5,44 @@ import json
import math
import mimetypes
import os
+import platform
import random
+import subprocess as sp
import sys
import tempfile
import time
import traceback
-import platform
-import subprocess as sp
from functools import partial, reduce
+import gradio as gr
+import gradio.routes
+import gradio.utils
import numpy as np
+import piexif
import torch
from PIL import Image, PngImagePlugin
-import piexif
-import gradio as gr
-import gradio.utils
-import gradio.routes
-
-from modules import sd_hijack, sd_models, localization
+from modules import localization, sd_hijack, sd_models
from modules.paths import script_path
-from modules.shared import opts, cmd_opts, restricted_opts
+from modules.shared import cmd_opts, opts, restricted_opts
+
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
-import modules.shared as shared
-from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.sd_hijack import model_hijack
+
+import modules.codeformer_model
+import modules.generation_parameters_copypaste
+import modules.gfpgan_model
+import modules.hypernetworks.ui
+import modules.images_history as img_his
import modules.ldsr_model
import modules.scripts
-import modules.gfpgan_model
-import modules.codeformer_model
+import modules.shared as shared
import modules.styles
-import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
from modules import prompt_parser
from modules.images import save_image
-import modules.textual_inversion.ui
-import modules.hypernetworks.ui
-import modules.images_history as img_his
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -268,8 +269,8 @@ def calc_time_left(progress, threshold, label, force_display):
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
- if (eta_relative > threshold and progress > 0.02) or force_display:
- return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+ if (eta_relative > threshold and progress > 0.02) or force_display:
+ return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
else:
return ""
@@ -1219,6 +1220,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["relu", "leakyrelu"])
with gr.Row():
with gr.Column(scale=3):
@@ -1303,6 +1305,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes,
new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm,
+ new_hypernetwork_activation_func,
],
outputs=[
train_hypernetwork_name,
--
cgit v1.2.3
From ba469343e6a1c6e23e82acf5feb65c6101dacbb2 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Thu, 20 Oct 2022 00:17:04 +0000
Subject: align ui.py imports with upstream
---
modules/ui.py | 37 ++++++++++++++++++-------------------
1 file changed, 18 insertions(+), 19 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 987b1d7d..913b23b4 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -5,44 +5,43 @@ import json
import math
import mimetypes
import os
-import platform
import random
-import subprocess as sp
import sys
import tempfile
import time
import traceback
+import platform
+import subprocess as sp
from functools import partial, reduce
-import gradio as gr
-import gradio.routes
-import gradio.utils
import numpy as np
-import piexif
import torch
from PIL import Image, PngImagePlugin
+import piexif
-from modules import localization, sd_hijack, sd_models
-from modules.paths import script_path
-from modules.shared import cmd_opts, opts, restricted_opts
+import gradio as gr
+import gradio.utils
+import gradio.routes
+from modules import sd_hijack, sd_models, localization
+from modules.paths import script_path
+from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
-
-import modules.codeformer_model
-import modules.generation_parameters_copypaste
-import modules.gfpgan_model
-import modules.hypernetworks.ui
-import modules.images_history as img_his
+import modules.shared as shared
+from modules.sd_samplers import samplers, samplers_for_img2img
+from modules.sd_hijack import model_hijack
import modules.ldsr_model
import modules.scripts
-import modules.shared as shared
+import modules.gfpgan_model
+import modules.codeformer_model
import modules.styles
-import modules.textual_inversion.ui
+import modules.generation_parameters_copypaste
from modules import prompt_parser
from modules.images import save_image
-from modules.sd_hijack import model_hijack
-from modules.sd_samplers import samplers, samplers_for_img2img
+import modules.textual_inversion.ui
+import modules.hypernetworks.ui
+import modules.images_history as img_his
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
--
cgit v1.2.3
From 858462f719c22ca9f24b94a41699653c34b5f4fb Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Thu, 20 Oct 2022 02:57:18 +0100
Subject: do caption copy for both flips
---
modules/textual_inversion/preprocess.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 3713bc89..6bba3852 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -82,7 +82,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
subindex[0] += 1
def save_pic(image, index, existing_caption=None):
- save_pic_with_caption(image, index)
+ save_pic_with_caption(image, index, existing_caption=existing_caption)
if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
--
cgit v1.2.3
From aa7ff2a1972f3865883e10ba28c5414cdebe8e3b Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Wed, 19 Oct 2022 21:46:13 -0700
Subject: Fixed non-square highres fix generation
---
modules/processing.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 684e5833..3caac25e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -541,10 +541,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def create_dummy_mask(self, x):
+ def create_dummy_mask(self, x, first_phase: bool = False):
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+ height = self.firstphase_height if first_phase else self.height
+ width = self.firstphase_width if first_phase else self.width
+
# The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device)
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
@@ -567,7 +570,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, first_phase=True))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
--
cgit v1.2.3
From 930b4c64f7dbce6918894d53538003e5959fd022 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 20 Oct 2022 08:18:02 +0300
Subject: allow float sizes for hypernet's layer_structure
---
modules/hypernetworks/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 08f75f15..e0741d08 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -15,7 +15,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
- layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure)))
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
name=name,
--
cgit v1.2.3
From f8733ad08be08bafb40f4299785590e11f049e96 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Thu, 20 Oct 2022 11:07:37 +0000
Subject: add linear as a act func (option for doin nothing)
---
modules/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 913b23b4..716f14b8 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1224,7 +1224,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["relu", "leakyrelu"])
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row():
with gr.Column(scale=3):
--
cgit v1.2.3
From 9681419e422515e42444e0174355b760645a846f Mon Sep 17 00:00:00 2001
From: Milly
Date: Thu, 20 Oct 2022 16:53:46 +0900
Subject: train: fixed preprocess image ratio
---
modules/textual_inversion/preprocess.py | 54 +++++++++++++++++++++------------
1 file changed, 35 insertions(+), 19 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 886cf0c3..2743bdeb 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,5 +1,6 @@
import os
from PIL import Image, ImageOps
+import math
import platform
import sys
import tqdm
@@ -38,6 +39,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
+ split_threshold = 0.5
+ overlap_ratio = 0.2
assert src != dst, 'same directory specified as source and destination'
@@ -78,6 +81,29 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index)
+ def split_pic(image, inverse_xy):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
+ else:
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
+
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
@@ -89,26 +115,16 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
if shared.state.interrupted:
break
- ratio = img.height / img.width
- is_tall = ratio > 1.35
- is_wide = ratio < 1 / 1.35
-
- if process_split and is_tall:
- img = img.resize((width, height * img.height // img.width))
-
- top = img.crop((0, 0, width, height))
- save_pic(top, index)
-
- bot = img.crop((0, img.height - height, width, img.height))
- save_pic(bot, index)
- elif process_split and is_wide:
- img = img.resize((width * img.width // img.height, height))
-
- left = img.crop((0, 0, width, height))
- save_pic(left, index)
+ if img.height > img.width:
+ ratio = (img.width * height) / (img.height * width)
+ inverse_xy = False
+ else:
+ ratio = (img.height * width) / (img.width * height)
+ inverse_xy = True
- right = img.crop((img.width - width, 0, img.width, height))
- save_pic(right, index)
+ if process_split and ratio < 1.0 and ratio <= split_threshold:
+ for splitted in split_pic(img, inverse_xy):
+ save_pic(splitted, index)
else:
img = images.resize_image(1, img, width, height)
save_pic(img, index)
--
cgit v1.2.3
From 85dd62c4c7635b8e21a75f140d093036069e97a1 Mon Sep 17 00:00:00 2001
From: Milly
Date: Thu, 20 Oct 2022 22:56:45 +0900
Subject: train: ui: added `Split image threshold` and `Split image overlap
ratio` to preprocess
---
modules/textual_inversion/preprocess.py | 10 +++++-----
modules/ui.py | 16 ++++++++++++++--
2 files changed, 19 insertions(+), 7 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 2743bdeb..c8df8aa0 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -12,7 +12,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
try:
if process_caption:
shared.interrogator.load()
@@ -22,7 +22,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
- preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+ preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio)
finally:
@@ -34,13 +34,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
width = process_width
height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
- split_threshold = 0.5
- overlap_ratio = 0.2
+ split_threshold = max(0.0, min(1.0, split_threshold))
+ overlap_ratio = max(0.0, min(0.9, overlap_ratio))
assert src != dst, 'same directory specified as source and destination'
diff --git a/modules/ui.py b/modules/ui.py
index a2dbd41e..bc7f3330 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1240,10 +1240,14 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
- process_split = gr.Checkbox(label='Split oversized images into two')
+ process_split = gr.Checkbox(label='Split oversized images')
process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
+ with gr.Row(visible=False) as process_split_extra_row:
+ process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
+ process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
+
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1251,6 +1255,12 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
run_preprocess = gr.Button(value="Preprocess", variant='primary')
+ process_split.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_split],
+ outputs=[process_split_extra_row],
+ )
+
with gr.Tab(label="Train"):
gr.HTML(value="Train an embedding; must specify a directory with a set of 1:1 ratio images
")
with gr.Row():
@@ -1327,7 +1337,9 @@ def create_ui(wrap_gradio_gpu_call):
process_flip,
process_split,
process_caption,
- process_caption_deepbooru
+ process_caption_deepbooru,
+ process_split_threshold,
+ process_overlap_ratio,
],
outputs=[
ti_output,
--
cgit v1.2.3
From d8acd34f66ab35a91f10d66330bcc95a83bfcac6 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Thu, 20 Oct 2022 23:43:03 +0900
Subject: generalized some functions and option for ignoring first layer
---
modules/hypernetworks/hypernetwork.py | 23 +++++++++++++++--------
1 file changed, 15 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7d617680..3a44b377 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -21,21 +21,27 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
-
+ activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU,
+ "swish": torch.nn.Hardswish}
+
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
-
+
linears = []
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
- if activation_func == "relu":
- linears.append(torch.nn.ReLU())
- if activation_func == "leakyrelu":
- linears.append(torch.nn.LeakyReLU())
+ # if skip_first_layer because first parameters potentially contain negative values
+ if i < 1: continue
+ if activation_func in HypernetworkModule.activation_dict:
+ linears.append(HypernetworkModule.activation_dict[activation_func]())
+ else:
+ print("Invalid key {} encountered as activation function!".format(activation_func))
+ # if use_dropout:
+ linears.append(torch.nn.Dropout(p=0.3))
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
- if not "ReLU" in layer.__str__():
+ if isinstance(layer, torch.nn.Linear):
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
@@ -298,7 +304,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+ # if optimizer == "Adam": or else Adam / AdamW / etc...
+ optimizer = torch.optim.Adam(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
--
cgit v1.2.3
From a71e0212363979c7cbbb797c9fbd5f8cd03b29d3 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Thu, 20 Oct 2022 23:48:52 +0900
Subject: only linear
---
modules/hypernetworks/hypernetwork.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3a44b377..905cbeef 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -35,13 +35,13 @@ class HypernetworkModule(torch.nn.Module):
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# if skip_first_layer because first parameters potentially contain negative values
- if i < 1: continue
+ # if i < 1: continue
if activation_func in HypernetworkModule.activation_dict:
linears.append(HypernetworkModule.activation_dict[activation_func]())
else:
print("Invalid key {} encountered as activation function!".format(activation_func))
# if use_dropout:
- linears.append(torch.nn.Dropout(p=0.3))
+ # linears.append(torch.nn.Dropout(p=0.3))
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
- if not "ReLU" in layer.__str__():
+ if isinstance(layer, torch.nn.Linear):
layer_structure += [layer.weight, layer.bias]
return layer_structure
@@ -304,8 +304,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- # if optimizer == "Adam": or else Adam / AdamW / etc...
- optimizer = torch.optim.Adam(weights, lr=scheduler.learn_rate)
+ # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
--
cgit v1.2.3
From 108be15500aac590b4e00420635d7b61fccfa530 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Fri, 21 Oct 2022 01:00:41 +0900
Subject: fix bugs and optimizations
---
modules/hypernetworks/hypernetwork.py | 105 +++++++++++++++++++---------------
1 file changed, 59 insertions(+), 46 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 905cbeef..893ba110 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -36,14 +36,14 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# if skip_first_layer because first parameters potentially contain negative values
# if i < 1: continue
+ if add_layer_norm:
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
if activation_func in HypernetworkModule.activation_dict:
linears.append(HypernetworkModule.activation_dict[activation_func]())
else:
print("Invalid key {} encountered as activation function!".format(activation_func))
# if use_dropout:
# linears.append(torch.nn.Dropout(p=0.3))
- if add_layer_norm:
- linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
self.linear = torch.nn.Sequential(*linears)
@@ -115,11 +115,24 @@ class Hypernetwork:
for k, layers in self.layers.items():
for layer in layers:
- layer.train()
res += layer.trainables()
return res
+ def eval(self):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.eval()
+ for items in self.weights():
+ items.requires_grad = False
+
+ def train(self):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.train()
+ for items in self.weights():
+ items.requires_grad = True
+
def save(self, filename):
state_dict = {}
@@ -290,10 +303,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork = shared.loaded_hypernetwork
- weights = hypernetwork.weights()
- for weight in weights:
- weight.requires_grad = True
-
losses = torch.zeros((32,))
last_saved_file = ""
@@ -304,10 +313,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
- optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+ optimizer = torch.optim.AdamW(hypernetwork.weights(), lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
+ hypernetwork.train()
for i, entries in pbar:
hypernetwork.step = i + ititial_step
@@ -328,8 +337,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item()
- optimizer.zero_grad()
+ optimizer.zero_grad(set_to_none=True)
loss.backward()
+ del loss
optimizer.step()
mean_loss = losses.mean()
if torch.isnan(mean_loss):
@@ -346,44 +356,47 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
+ torch.cuda.empty_cache()
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+ with torch.no_grad():
+ hypernetwork.eval()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
- optimizer.zero_grad()
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
-
- preview_text = p.prompt
-
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images)>0 else None
-
- if unload:
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
-
- if image is not None:
- shared.state.current_image = image
- image.save(last_saved_image)
- last_saved_image += f", prompt: {preview_text}"
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images)>0 else None
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ if image is not None:
+ shared.state.current_image = image
+ image.save(last_saved_image)
+ last_saved_image += f", prompt: {preview_text}"
+
+ hypernetwork.train()
shared.state.job_no = hypernetwork.step
--
cgit v1.2.3
From f89829ec3a0baceb445451ad98d4fb4323e922aa Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Fri, 21 Oct 2022 01:37:11 +0900
Subject: Revert "fix bugs and optimizations"
This reverts commit 108be15500aac590b4e00420635d7b61fccfa530.
---
modules/hypernetworks/hypernetwork.py | 105 +++++++++++++++-------------------
1 file changed, 46 insertions(+), 59 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 893ba110..905cbeef 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -36,14 +36,14 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# if skip_first_layer because first parameters potentially contain negative values
# if i < 1: continue
- if add_layer_norm:
- linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
if activation_func in HypernetworkModule.activation_dict:
linears.append(HypernetworkModule.activation_dict[activation_func]())
else:
print("Invalid key {} encountered as activation function!".format(activation_func))
# if use_dropout:
# linears.append(torch.nn.Dropout(p=0.3))
+ if add_layer_norm:
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
self.linear = torch.nn.Sequential(*linears)
@@ -115,24 +115,11 @@ class Hypernetwork:
for k, layers in self.layers.items():
for layer in layers:
+ layer.train()
res += layer.trainables()
return res
- def eval(self):
- for k, layers in self.layers.items():
- for layer in layers:
- layer.eval()
- for items in self.weights():
- items.requires_grad = False
-
- def train(self):
- for k, layers in self.layers.items():
- for layer in layers:
- layer.train()
- for items in self.weights():
- items.requires_grad = True
-
def save(self, filename):
state_dict = {}
@@ -303,6 +290,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork = shared.loaded_hypernetwork
+ weights = hypernetwork.weights()
+ for weight in weights:
+ weight.requires_grad = True
+
losses = torch.zeros((32,))
last_saved_file = ""
@@ -313,10 +304,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- optimizer = torch.optim.AdamW(hypernetwork.weights(), lr=scheduler.learn_rate)
+ # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- hypernetwork.train()
for i, entries in pbar:
hypernetwork.step = i + ititial_step
@@ -337,9 +328,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item()
- optimizer.zero_grad(set_to_none=True)
+ optimizer.zero_grad()
loss.backward()
- del loss
optimizer.step()
mean_loss = losses.mean()
if torch.isnan(mean_loss):
@@ -356,47 +346,44 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
- torch.cuda.empty_cache()
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
- with torch.no_grad():
- hypernetwork.eval()
- shared.sd_model.cond_stage_model.to(devices.device)
- shared.sd_model.first_stage_model.to(devices.device)
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
-
- preview_text = p.prompt
-
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images)>0 else None
-
- if unload:
- shared.sd_model.cond_stage_model.to(devices.cpu)
- shared.sd_model.first_stage_model.to(devices.cpu)
-
- if image is not None:
- shared.state.current_image = image
- image.save(last_saved_image)
- last_saved_image += f", prompt: {preview_text}"
-
- hypernetwork.train()
+ optimizer.zero_grad()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images)>0 else None
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ if image is not None:
+ shared.state.current_image = image
+ image.save(last_saved_image)
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
--
cgit v1.2.3
From 92a17a7a4a13fceb3c3e25a2e854b2a7dd6eb5df Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Thu, 20 Oct 2022 09:45:03 -0700
Subject: Made dummy latents smaller. Minor code cleanups
---
modules/processing.py | 7 ++++---
modules/sd_samplers.py | 6 ++++--
2 files changed, 8 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 3caac25e..539cde38 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -557,7 +557,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
- image_conditioning = torch.zeros(x.shape[0], 5, x.shape[-2], x.shape[-1], dtype=x.dtype, device=x.device)
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
return image_conditioning
@@ -759,8 +760,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
else:
self.image_conditioning = torch.zeros(
- self.init_latent.shape[0], 5, self.init_latent.shape[-2], self.init_latent.shape[-1],
- dtype=self.init_latent.dtype,
+ self.init_latent.shape[0], 5, 1, 1,
+ dtype=self.init_latent.dtype,
device=self.init_latent.device
)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index c21be26e..cc682593 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -138,7 +138,7 @@ class VanillaStableDiffusionSampler:
if self.stop_at is not None and self.step > self.stop_at:
raise InterruptedException
- # Have to unwrap the inpainting conditioning here to perform pre-preocessing
+ # Have to unwrap the inpainting conditioning here to perform pre-processing
image_conditioning = None
if isinstance(cond, dict):
image_conditioning = cond["c_concat"][0]
@@ -146,7 +146,7 @@ class VanillaStableDiffusionSampler:
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
@@ -165,6 +165,8 @@ class VanillaStableDiffusionSampler:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
+ # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
+ # Note that they need to be lists because it just concatenates them later.
if image_conditioning is not None:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
--
cgit v1.2.3
From d1cb08bfb221cd1b0cfc6078162b4e206ea80a5c Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Thu, 20 Oct 2022 22:49:06 +0300
Subject: fix skip and interrupt for highres. fix option
---
modules/processing.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index bcb0c32c..6324ca91 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -587,9 +587,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
-
- return samples
+ return self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) or samples
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
--
cgit v1.2.3
From 708c3a7bd8ce68cbe1aa7c268e5a4b1980affc9f Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Thu, 20 Oct 2022 13:28:43 -0700
Subject: Added PLMS hijack and made sure to always replace methods
---
modules/sd_hijack_inpainting.py | 163 ++++++++++++++++++++++++++++++++++++++--
modules/sd_models.py | 3 +-
2 files changed, 157 insertions(+), 9 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index d4d28d2e..43938071 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -1,16 +1,14 @@
import torch
-import numpy as np
-from tqdm import tqdm
-from einops import rearrange, repeat
+from einops import repeat
from omegaconf import ListConfig
-from types import MethodType
-
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
# =================================================================================================
@@ -19,7 +17,7 @@ from ldm.models.diffusion.ddim import DDIMSampler, noise_like
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
# =================================================================================================
@torch.no_grad()
-def sample(self,
+def sample_ddim(self,
S,
batch_size,
shape,
@@ -132,6 +130,153 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F
return x_prev, pred_x0
+# =================================================================================================
+# Monkey patch PLMSSampler methods.
+# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
+# Adapted from:
+# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
+# =================================================================================================
+@torch.no_grad()
+def sample_plms(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list):
+ ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+
+@torch.no_grad()
+def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
+ for i in range(len(c[k]))
+ ]
+ else:
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
+
# =================================================================================================
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
# Adapted from:
@@ -175,5 +320,9 @@ def should_hijack_inpainting(checkpoint_info):
def do_inpainting_hijack():
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
+
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
- ldm.models.diffusion.ddim.DDIMSampler.sample = sample
\ No newline at end of file
+ ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
+
+ ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
+ ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
\ No newline at end of file
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 47836d25..7072db08 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -214,8 +214,6 @@ def load_model():
sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
- do_inpainting_hijack()
-
# Hardcoded config for now...
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
sd_config.model.params.use_ema = False
@@ -225,6 +223,7 @@ def load_model():
# Create a "fake" config with a different name so that we know to unload it when switching models.
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+ do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
--
cgit v1.2.3
From d23a46ceaa76af2847f11172f32c92665c268b1b Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Thu, 20 Oct 2022 23:49:14 +0300
Subject: Different approach to skip/interrupt with highres fix
---
modules/processing.py | 4 +++-
modules/sd_samplers.py | 4 ++++
2 files changed, 7 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 6324ca91..bcb0c32c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -587,7 +587,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- return self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) or samples
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+
+ return samples
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index b58e810b..7ff77c01 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -196,6 +196,7 @@ class VanillaStableDiffusionSampler:
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
+ self.last_latent = x
self.step = 0
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
@@ -206,6 +207,7 @@ class VanillaStableDiffusionSampler:
self.initialize(p)
self.init_latent = None
+ self.last_latent = x
self.step = 0
steps = steps or p.steps
@@ -388,6 +390,7 @@ class KDiffusionSampler:
extra_params_kwargs['sigmas'] = sigma_sched
self.model_wrap_cfg.init_latent = x
+ self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
@@ -414,6 +417,7 @@ class KDiffusionSampler:
else:
extra_params_kwargs['sigmas'] = sigmas
+ self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
--
cgit v1.2.3
From 49533eed9e3aad19e9868ee140708baec4fd44be Mon Sep 17 00:00:00 2001
From: random_thoughtss
Date: Thu, 20 Oct 2022 16:01:27 -0700
Subject: XY grid correctly re-assignes model when config changes
---
modules/sd_models.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7072db08..fea84630 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -204,9 +204,9 @@ def load_model_weights(model, checkpoint_info):
model.sd_checkpoint_info = checkpoint_info
-def load_model():
+def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
- checkpoint_info = select_checkpoint()
+ checkpoint_info = checkpoint_info or select_checkpoint()
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
@@ -249,7 +249,7 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
- shared.sd_model = load_model()
+ shared.sd_model = load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
--
cgit v1.2.3
From 45872181902ada06267e2de601586d512cf5df1a Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 09:00:39 +0300
Subject: updated readme and some small stylistic changes to code
---
modules/processing.py | 14 ++++++--------
modules/sd_hijack_inpainting.py | 3 +++
2 files changed, 9 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 539cde38..21786968 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -540,11 +540,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
-
- def create_dummy_mask(self, x, first_phase: bool = False):
+ def create_dummy_mask(self, x, width=None, height=None):
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- height = self.firstphase_height if first_phase else self.height
- width = self.firstphase_width if first_phase else self.width
+ height = height or self.height
+ width = width or self.width
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
@@ -571,7 +570,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, first_phase=True))
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
@@ -634,6 +633,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.inpainting_mask_invert = inpainting_mask_invert
self.mask = None
self.nmask = None
+ self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
@@ -735,9 +735,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- conditioning_key = self.sampler.conditioning_key
-
- if conditioning_key in {'hybrid', 'concat'}:
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
if self.image_mask is not None:
conditioning_mask = np.array(self.image_mask.convert("L"))
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 43938071..fd92a335 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -301,6 +301,7 @@ def get_unconditional_conditioning(self, batch_size, null_label=None):
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
return c
+
class LatentInpaintDiffusion(LatentDiffusion):
def __init__(
self,
@@ -314,9 +315,11 @@ class LatentInpaintDiffusion(LatentDiffusion):
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
+
def should_hijack_inpainting(checkpoint_info):
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
+
def do_inpainting_hijack():
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
--
cgit v1.2.3
From 74088c2a06a975092806362aede22f82716cb011 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 20 Oct 2022 08:18:02 +0300
Subject: allow float sizes for hypernet's layer_structure
---
modules/hypernetworks/ui.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 08f75f15..e0741d08 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -15,7 +15,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
- layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure)))
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
name=name,
--
cgit v1.2.3
From 60872c5b404114336f9ca0c671ba88fa4a8201c9 Mon Sep 17 00:00:00 2001
From: winterspringsummer
Date: Thu, 20 Oct 2022 19:10:32 +0900
Subject: Fixed path issue while extras batch processing
---
modules/extras.py | 12 ++++++++----
1 file changed, 8 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index b853fa5b..f9796624 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -118,10 +118,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
-
- images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
- forced_filename=image_name if opts.use_original_name_batch else None)
+
+ if opts.use_original_name_batch and image_name != None:
+ basename = os.path.splitext(os.path.basename(image_name))[0]
+ else:
+ basename = ''
+
+ images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+ no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if opts.enable_pnginfo:
image.info = existing_pnginfo
--
cgit v1.2.3
From fb5a8cf0d9ed027ea3aa2e5422c946d8e6e72efe Mon Sep 17 00:00:00 2001
From: winterspringsummer
Date: Thu, 20 Oct 2022 21:31:29 +0900
Subject: Added try except to extras batch from directory
---
modules/extras.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index f9796624..0d817cf9 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -41,7 +41,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return outputs, "Please select an input directory.", ''
image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
for img in image_list:
- image = Image.open(img)
+ try:
+ image = Image.open(img)
+ except Exception:
+ continue
imageArr.append(image)
imageNameArr.append(img)
else:
@@ -122,10 +125,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
- basename = ''
+ basename = None
- images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
+ images.save_image(image, path=outpath, basename='', seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+ no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=basename)
if opts.enable_pnginfo:
image.info = existing_pnginfo
--
cgit v1.2.3
From a13c3bed3cec27afe3c015d3d62db36e25b10d1f Mon Sep 17 00:00:00 2001
From: winterspringsummer
Date: Thu, 20 Oct 2022 21:43:27 +0900
Subject: Fixed path issue while extras batch processing
---
modules/extras.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index 0d817cf9..ac85142c 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -125,10 +125,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
- basename = None
+ basename = ''
- images.save_image(image, path=outpath, basename='', seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=basename)
+ images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+ no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if opts.enable_pnginfo:
image.info = existing_pnginfo
--
cgit v1.2.3
From 9d71eef02e7395e179b8d5e61e6d91ddd8928d2e Mon Sep 17 00:00:00 2001
From: winterspringsummer
Date: Fri, 21 Oct 2022 09:23:13 +0900
Subject: sort file list in alphabetical ordering in extras
---
modules/extras.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/extras.py b/modules/extras.py
index ac85142c..22c5a1c1 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -39,7 +39,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if input_dir == '':
return outputs, "Please select an input directory.", ''
- image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+ image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
for img in image_list:
try:
image = Image.open(img)
--
cgit v1.2.3
From c23f666dba2b484d521d2dc4be91cf9e09312647 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 09:47:43 +0300
Subject: a more strict check for activation type and a more reasonable check
for type of layer in hypernets
---
modules/hypernetworks/hypernetwork.py | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7d617680..84e7e350 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -32,10 +32,16 @@ class HypernetworkModule(torch.nn.Module):
linears = []
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
if activation_func == "relu":
linears.append(torch.nn.ReLU())
- if activation_func == "leakyrelu":
+ elif activation_func == "leakyrelu":
linears.append(torch.nn.LeakyReLU())
+ elif activation_func == 'linear' or activation_func is None:
+ pass
+ else:
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
- if not "ReLU" in layer.__str__():
+ if type(layer) == torch.nn.Linear:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
@@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
- if not "ReLU" in layer.__str__():
+ if type(layer) == torch.nn.Linear:
layer_structure += [layer.weight, layer.bias]
return layer_structure
--
cgit v1.2.3
From 7157e5d064741fa57ca81a2c6432a651f21ee82f Mon Sep 17 00:00:00 2001
From: Patryk Wychowaniec
Date: Thu, 20 Oct 2022 19:22:59 +0200
Subject: interrogate: Fix CLIP-interrogation on CPU
Currently, trying to perform CLIP interrogation on a CPU fails, saying:
```
RuntimeError: "slow_conv2d_cpu" not implemented for 'Half'
```
This merge request fixes this issue by detecting whether the target
device is CPU and, if so, force-enabling `--no-half` and passing
`device="cpu"` to `clip.load()` (which then does some extra tricks to
ensure it works correctly on CPU).
---
modules/interrogate.py | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 64b91eb4..65b05d34 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -28,9 +28,11 @@ class InterrogateModels:
clip_preprocess = None
categories = None
dtype = None
+ running_on_cpu = None
def __init__(self, content_dir):
self.categories = []
+ self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
if os.path.exists(content_dir):
for filename in os.listdir(content_dir):
@@ -53,7 +55,11 @@ class InterrogateModels:
def load_clip_model(self):
import clip
- model, preprocess = clip.load(clip_model_name)
+ if self.running_on_cpu:
+ model, preprocess = clip.load(clip_model_name, device="cpu")
+ else:
+ model, preprocess = clip.load(clip_model_name)
+
model.eval()
model = model.to(devices.device_interrogate)
@@ -62,14 +68,14 @@ class InterrogateModels:
def load(self):
if self.blip_model is None:
self.blip_model = self.load_blip_model()
- if not shared.cmd_opts.no_half:
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.blip_model = self.blip_model.half()
self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
- if not shared.cmd_opts.no_half:
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.clip_model = self.clip_model.half()
self.clip_model = self.clip_model.to(devices.device_interrogate)
--
cgit v1.2.3
From b69c37d25e4ffc56e8f8c247fa2c38b4648cefb7 Mon Sep 17 00:00:00 2001
From: guaneec
Date: Thu, 20 Oct 2022 22:21:12 +0800
Subject: Allow datasets with only 1 image in TI
---
modules/textual_inversion/dataset.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 23bb4b6a..5b1c5002 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -83,7 +83,7 @@ class PersonalizedBase(Dataset):
self.dataset.append(entry)
- assert len(self.dataset) > 1, "No images have been found in the dataset."
+ assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(len(self.dataset))
@@ -91,7 +91,7 @@ class PersonalizedBase(Dataset):
self.shuffle()
def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
def create_text(self, filename_text):
text = random.choice(self.lines)
--
cgit v1.2.3
From 5245c7a4935f67b677da0f5a1fc2b74c074aa0e2 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Wed, 19 Oct 2022 12:21:32 -0700
Subject: Issue #2921-Give PNG info to Hypernet previews.
---
modules/hypernetworks/hypernetwork.py | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 84e7e350..68c8f26d 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -256,6 +256,9 @@ def stack_conds(conds):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ # images is required here to give training previews their infotext. Importing this at the very top causes a circular dependency.
+ from modules import images
+
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -298,6 +301,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
last_saved_file = ""
last_saved_image = ""
+ forced_filename = ""
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
@@ -345,7 +349,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
- last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+ forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
@@ -381,7 +386,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None:
shared.state.current_image = image
- image.save(last_saved_image)
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
--
cgit v1.2.3
From 6014fb8afbe05c8d02fffe7a36a2e48128713bd2 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Wed, 19 Oct 2022 12:22:23 -0700
Subject: Do nothing if image file already exists.
---
modules/images.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index b9589563..550e53ae 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -416,7 +416,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
path = os.path.join(path, dirname)
- os.makedirs(path, exist_ok=True)
+ try:
+ os.makedirs(path, exist_ok=True)
+ except FileExistsError:
+ # If the file already exists, continue and allow said file to be overwritten.
+ pass
if forced_filename is None:
basecount = get_next_sequence_number(path, basename)
--
cgit v1.2.3
From 4ff274e1e35bb642687253ce744d2cfa738ab293 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Wed, 19 Oct 2022 12:32:22 -0700
Subject: Revise comments.
---
modules/hypernetworks/hypernetwork.py | 2 +-
modules/images.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 68c8f26d..3f96361c 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -256,7 +256,7 @@ def stack_conds(conds):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- # images is required here to give training previews their infotext. Importing this at the very top causes a circular dependency.
+ # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
assert hypernetwork_name, 'hypernetwork not selected'
diff --git a/modules/images.py b/modules/images.py
index 550e53ae..b8834e3c 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -419,7 +419,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
try:
os.makedirs(path, exist_ok=True)
except FileExistsError:
- # If the file already exists, continue and allow said file to be overwritten.
+ # If the file already exists, allow said file to be overwritten.
pass
if forced_filename is None:
--
cgit v1.2.3
From 2273e752fb3e578f1047f6d38b96330b07bf61a9 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Wed, 19 Oct 2022 14:23:48 -0700
Subject: Remove redundant try/except.
---
modules/images.py | 6 +-----
1 file changed, 1 insertion(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/images.py b/modules/images.py
index b8834e3c..b9589563 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -416,11 +416,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
path = os.path.join(path, dirname)
- try:
- os.makedirs(path, exist_ok=True)
- except FileExistsError:
- # If the file already exists, allow said file to be overwritten.
- pass
+ os.makedirs(path, exist_ok=True)
if forced_filename is None:
basecount = get_next_sequence_number(path, basename)
--
cgit v1.2.3
From 03a1e288c4973dd2dff57a97469b40f146b6fccf Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 10:13:24 +0300
Subject: turns out LayerNorm also has weight and bias and needs to be
pre-multiplied and trained for hypernets
---
modules/hypernetworks/hypernetwork.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3274a802..b1a5d0c7 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -52,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
- if type(layer) == torch.nn.Linear:
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
@@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
- if type(layer) == torch.nn.Linear:
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer_structure += [layer.weight, layer.bias]
return layer_structure
--
cgit v1.2.3
From bf30673f5132c8f28357b31224c54331e788d3e7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 10:19:25 +0300
Subject: Fix Hypernet infotext string split bug for PR #3283
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 21786968..d1deffa9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else os.path.splitext(os.path.basename(shared.loaded_hypernetwork.filename))[0]),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
--
cgit v1.2.3
From df5706409386cc2e88718bd9101045587c39f8bb Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 16:10:51 +0300
Subject: do not load aesthetic clip model until it's needed add refresh button
for aesthetic embeddings add aesthetic params to images' infotext
---
modules/aesthetic_clip.py | 40 +++++++++++++++++++----
modules/generation_parameters_copypaste.py | 18 +++++++++--
modules/img2img.py | 5 +--
modules/processing.py | 4 +--
modules/sd_models.py | 3 --
modules/txt2img.py | 4 +--
modules/ui.py | 52 ++++++++++++++++++++----------
7 files changed, 88 insertions(+), 38 deletions(-)
(limited to 'modules')
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
index 34efa931..8c828541 100644
--- a/modules/aesthetic_clip.py
+++ b/modules/aesthetic_clip.py
@@ -40,6 +40,8 @@ def iter_to_batched(iterable, n=1):
def create_ui():
+ import modules.ui
+
with gr.Group():
with gr.Accordion("Open for Clip Aesthetic!", open=False):
with gr.Row():
@@ -55,6 +57,8 @@ def create_ui():
label="Aesthetic imgs embedding",
value="None")
+ modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")
+
with gr.Row():
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
placeholder="This text is used to rotate the feature space of the imgs embs",
@@ -66,11 +70,21 @@ def create_ui():
return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
+aesthetic_clip_model = None
+
+
+def aesthetic_clip():
+ global aesthetic_clip_model
+
+ if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
+ aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
+ aesthetic_clip_model.cpu()
+
+ return aesthetic_clip_model
+
+
def generate_imgs_embd(name, folder, batch_size):
- # clipModel = CLIPModel.from_pretrained(
- # shared.sd_model.cond_stage_model.clipModel.name_or_path
- # )
- model = shared.clip_model.to(device)
+ model = aesthetic_clip().to(device)
processor = CLIPProcessor.from_pretrained(model.name_or_path)
with torch.no_grad():
@@ -91,7 +105,7 @@ def generate_imgs_embd(name, folder, batch_size):
path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
torch.save(embs, path)
- model = model.cpu()
+ model.cpu()
del processor
del embs
gc.collect()
@@ -132,7 +146,7 @@ class AestheticCLIP:
self.image_embs = None
self.load_image_embs(None)
- def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
+ def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
aesthetic_slerp=True, aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False):
@@ -145,6 +159,18 @@ class AestheticCLIP:
self.aesthetic_steps = aesthetic_steps
self.load_image_embs(image_embs_name)
+ if self.image_embs_name is not None:
+ p.extra_generation_params.update({
+ "Aesthetic LR": aesthetic_lr,
+ "Aesthetic weight": aesthetic_weight,
+ "Aesthetic steps": aesthetic_steps,
+ "Aesthetic embedding": self.image_embs_name,
+ "Aesthetic slerp": aesthetic_slerp,
+ "Aesthetic text": aesthetic_imgs_text,
+ "Aesthetic text negative": aesthetic_text_negative,
+ "Aesthetic slerp angle": aesthetic_slerp_angle,
+ })
+
def set_skip(self, skip):
self.skip = skip
@@ -168,7 +194,7 @@ class AestheticCLIP:
tokens = torch.asarray(remade_batch_tokens).to(device)
- model = copy.deepcopy(shared.clip_model).to(device)
+ model = copy.deepcopy(aesthetic_clip()).to(device)
model.requires_grad_(True)
if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
text_embs_2 = model.get_text_features(
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 0f041449..f73647da 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -4,13 +4,22 @@ import gradio as gr
from modules.shared import script_path
from modules import shared
-re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
+re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())
+def quote(text):
+ if ',' not in str(text):
+ return text
+
+ text = str(text)
+ text = text.replace('\\', '\\\\')
+ text = text.replace('"', '\\"')
+ return f'"{text}"'
+
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
else:
try:
valtype = type(output.value)
- val = valtype(v)
+
+ if valtype == bool and v == "False":
+ val = False
+ else:
+ val = valtype(v)
+
res.append(gr.update(value=val))
except Exception:
res.append(gr.update())
diff --git a/modules/img2img.py b/modules/img2img.py
index bc7c66bc..eea5199b 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -109,10 +109,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert,
)
- shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
- aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text,
- aesthetic_slerp_angle,
- aesthetic_text_negative)
+ shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/processing.py b/modules/processing.py
index d1deffa9..f0852cd5 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -12,7 +12,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -318,7 +318,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params.update(p.extra_generation_params)
- generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+ generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 05a1df28..b1c91b0d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -234,9 +234,6 @@ def load_model(checkpoint_info=None):
sd_hijack.model_hijack.hijack(sd_model)
- if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path:
- shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path)
-
sd_model.eval()
print(f"Model loaded.")
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 32ed1d8d..1761cfa2 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -36,9 +36,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
firstphase_height=firstphase_height if enable_hr else None,
)
- shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
- aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle,
- aesthetic_text_negative)
+ shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/ui.py b/modules/ui.py
index 381ca925..0d020de6 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -597,27 +597,29 @@ def apply_setting(key, value):
return value
-def create_ui(wrap_gradio_gpu_call):
- import modules.img2img
- import modules.txt2img
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
- def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
- for k, v in args.items():
- setattr(refresh_component, k, v)
+ return gr.update(**(args or {}))
- return gr.update(**(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
+
+
+def create_ui(wrap_gradio_gpu_call):
+ import modules.img2img
+ import modules.txt2img
- refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn = refresh,
- inputs = [],
- outputs = [refresh_component]
- )
- return refresh_button
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
@@ -802,6 +804,14 @@ def create_ui(wrap_gradio_gpu_call):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
+ (aesthetic_lr, "Aesthetic LR"),
+ (aesthetic_weight, "Aesthetic weight"),
+ (aesthetic_steps, "Aesthetic steps"),
+ (aesthetic_imgs, "Aesthetic embedding"),
+ (aesthetic_slerp, "Aesthetic slerp"),
+ (aesthetic_imgs_text, "Aesthetic text"),
+ (aesthetic_text_negative, "Aesthetic text negative"),
+ (aesthetic_slerp_angle, "Aesthetic slerp angle"),
]
txt2img_preview_params = [
@@ -1077,6 +1087,14 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
+ (aesthetic_lr_im, "Aesthetic LR"),
+ (aesthetic_weight_im, "Aesthetic weight"),
+ (aesthetic_steps_im, "Aesthetic steps"),
+ (aesthetic_imgs_im, "Aesthetic embedding"),
+ (aesthetic_slerp_im, "Aesthetic slerp"),
+ (aesthetic_imgs_text_im, "Aesthetic text"),
+ (aesthetic_text_negative_im, "Aesthetic text negative"),
+ (aesthetic_slerp_angle_im, "Aesthetic slerp angle"),
]
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
--
cgit v1.2.3
From 9286fe53de2eef91f13cc3ad5938ddf67ecc8413 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 16:38:06 +0300
Subject: make aestetic embedding ciompatible with prompts longer than 75
tokens
---
modules/sd_hijack.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 36198a3c..1f8587d1 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -332,8 +332,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
+ z1 = shared.aesthetic_clip(z1, remade_batch_tokens)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
- z = shared.aesthetic_clip(z, remade_batch_tokens)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
--
cgit v1.2.3
From d0ea471b0cdaede163c6e7f6fae8535f5c3cd226 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Fri, 21 Oct 2022 14:04:41 +0100
Subject: Use opts in textual_inversion image_embedding.py for dynamic fonts
---
modules/textual_inversion/image_embedding.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index 898ce3b3..c50b1e7b 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -5,6 +5,7 @@ import zlib
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
from fonts.ttf import Roboto
import torch
+from modules.shared import opts
class EmbeddingEncoder(json.JSONEncoder):
--
cgit v1.2.3
From 306e2ff6ab8f4c7e94ab55f4f08ab8f94d73d287 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Fri, 21 Oct 2022 14:47:21 +0100
Subject: Update image_embedding.py
---
modules/textual_inversion/image_embedding.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index c50b1e7b..ea653806 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -134,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
from math import cos
image = srcimage.copy()
-
+ fontsize = 32
if textfont is None:
try:
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
@@ -151,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
draw = ImageDraw.Draw(image)
- fontsize = 32
+
font = ImageFont.truetype(textfont, fontsize)
padding = 10
--
cgit v1.2.3
From 51e3dc9ccad157d7161b697a246e26c868d46a7c Mon Sep 17 00:00:00 2001
From: timntorres
Date: Fri, 21 Oct 2022 02:11:12 -0700
Subject: Sanitize hypernet name input.
---
modules/hypernetworks/ui.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'modules')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 266f04f6..e6f50a1f 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -11,6 +11,9 @@ from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None):
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
--
cgit v1.2.3
From 19818f023cfafc472c6c241cab0b72896a168481 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Fri, 21 Oct 2022 02:14:02 -0700
Subject: Match hypernet name with filename in all cases.
---
modules/hypernetworks/hypernetwork.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b1a5d0c7..6d392be4 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -340,7 +340,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
pbar.set_description(f"loss: {mean_loss:.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
+ temp = hypernetwork.name
+ # Before saving, change name to match current checkpoint.
+ hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
@@ -405,6 +408,9 @@ Last saved image: {html.escape(last_saved_image)}
hypernetwork.sd_checkpoint = checkpoint.hash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
+ hypernetwork.name = hypernetwork_name
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(filename)
return hypernetwork, filename
--
cgit v1.2.3
From fccad18a59e3c2c33fefbbb1763c6a87a3a68eba Mon Sep 17 00:00:00 2001
From: timntorres
Date: Fri, 21 Oct 2022 02:17:26 -0700
Subject: Refer to Hypernet's name, sensibly, by its name variable.
---
modules/processing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index f0852cd5..ff1ec4c9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else os.path.splitext(os.path.basename(shared.loaded_hypernetwork.filename))[0]),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
--
cgit v1.2.3
From 272fa527bbe93143668ffc16838107b7dca35b40 Mon Sep 17 00:00:00 2001
From: timntorres
Date: Fri, 21 Oct 2022 02:41:55 -0700
Subject: Remove unused variable.
---
modules/hypernetworks/hypernetwork.py | 1 -
1 file changed, 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 6d392be4..47d91ea5 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -340,7 +340,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
pbar.set_description(f"loss: {mean_loss:.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
- temp = hypernetwork.name
# Before saving, change name to match current checkpoint.
hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
--
cgit v1.2.3
From 02e4d4694dd9254a6ca9f05c2eb7b01ea508abc7 Mon Sep 17 00:00:00 2001
From: Rcmcpe
Date: Fri, 21 Oct 2022 15:53:35 +0800
Subject: Change option description of unload_models_when_training
---
modules/shared.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 5c675b80..41d7f08e 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -266,7 +266,7 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('training', "Training"), {
- "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+ "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
--
cgit v1.2.3
From 704036ff07b71bf86cadcbbff2bcfeebdd1ed3a6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 17:11:42 +0300
Subject: make aspect ratio overlay work regardless of selected localization
---
modules/ui.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 0d020de6..85f95792 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -879,8 +879,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
--
cgit v1.2.3
From ac0aa2b18efeeb9220a5994c8dd54c7cdda7cc40 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 17:35:51 +0300
Subject: loading SD VAE, see PR #3303
---
modules/sd_models.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b1c91b0d..d99dbce8 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -155,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
+vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+
+
def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
@@ -186,7 +189,7 @@ def load_model_weights(model, checkpoint_info):
if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+ vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
--
cgit v1.2.3
From f49c08ea566385db339c6628f65c3a121033f67c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 21 Oct 2022 18:46:02 +0300
Subject: prevent error spam when processing images without txt files for
captions
---
modules/textual_inversion/preprocess.py | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
(limited to 'modules')
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 17e4ddc1..33eaddb6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -122,11 +122,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
continue
existing_caption = None
-
- try:
- existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read()
- except Exception as e:
- print(e)
+ existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
+ if os.path.exists(existing_caption_filename):
+ with open(existing_caption_filename, 'r', encoding="utf8") as file:
+ existing_caption = file.read()
if shared.state.interrupted:
break
--
cgit v1.2.3
From 57eb54b838faa383c10079e1bb5471b7bee6a695 Mon Sep 17 00:00:00 2001
From: Extraltodeus
Date: Sat, 22 Oct 2022 00:11:07 +0200
Subject: implement CUDA device selection by ID
---
modules/devices.py | 21 ++++++++++++++++++---
1 file changed, 18 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/devices.py b/modules/devices.py
index eb422583..8a159282 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,7 +1,6 @@
+import sys, os, shlex
import contextlib
-
import torch
-
from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
@@ -9,10 +8,26 @@ has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
+def extract_device_id(args, name):
+ for x in range(len(args)):
+ if name in args[x]: return args[x+1]
+ return None
def get_optimal_device():
if torch.cuda.is_available():
- return torch.device("cuda")
+ # CUDA device selection support:
+ if "shared" not in sys.modules:
+ commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop.
+ sys.argv += shlex.split(commandline_args)
+ device_id = extract_device_id(sys.argv, '--device-id')
+ else:
+ device_id = shared.cmd_opts.device_id
+
+ if device_id is not None:
+ cuda_device = f"cuda:{device_id}"
+ return torch.device(cuda_device)
+ else:
+ return torch.device("cuda")
if has_mps:
return torch.device("mps")
--
cgit v1.2.3
From 29bfacd63cb5c73b9643d94f255cca818fd49d9c Mon Sep 17 00:00:00 2001
From: Extraltodeus
Date: Sat, 22 Oct 2022 00:12:46 +0200
Subject: implement CUDA device selection, --device-id arg
---
modules/shared.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 41d7f08e..03032a47 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -80,6 +80,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
+parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
cmd_opts = parser.parse_args()
restricted_opts = [
--
cgit v1.2.3
From 2b91251637078e04472c91a06a8d9c4db9c1dcf0 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 12:23:45 +0300
Subject: removed aesthetic gradients as built-in added support for extensions
---
modules/aesthetic_clip.py | 241 --------------------------------------------
modules/images_history.py | 2 +-
modules/img2img.py | 5 +-
modules/processing.py | 35 ++++---
modules/script_callbacks.py | 42 ++++++++
modules/scripts.py | 210 ++++++++++++++++++++++++++++----------
modules/sd_hijack.py | 1 -
modules/sd_models.py | 7 +-
modules/shared.py | 19 ----
modules/txt2img.py | 5 +-
modules/ui.py | 83 +++------------
11 files changed, 244 insertions(+), 406 deletions(-)
delete mode 100644 modules/aesthetic_clip.py
create mode 100644 modules/script_callbacks.py
(limited to 'modules')
diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py
deleted file mode 100644
index 8c828541..00000000
--- a/modules/aesthetic_clip.py
+++ /dev/null
@@ -1,241 +0,0 @@
-import copy
-import itertools
-import os
-from pathlib import Path
-import html
-import gc
-
-import gradio as gr
-import torch
-from PIL import Image
-from torch import optim
-
-from modules import shared
-from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
-from tqdm.auto import tqdm, trange
-from modules.shared import opts, device
-
-
-def get_all_images_in_folder(folder):
- return [os.path.join(folder, f) for f in os.listdir(folder) if
- os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
-
-
-def check_is_valid_image_file(filename):
- return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp"))
-
-
-def batched(dataset, total, n=1):
- for ndx in range(0, total, n):
- yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
-
-
-def iter_to_batched(iterable, n=1):
- it = iter(iterable)
- while True:
- chunk = tuple(itertools.islice(it, n))
- if not chunk:
- return
- yield chunk
-
-
-def create_ui():
- import modules.ui
-
- with gr.Group():
- with gr.Accordion("Open for Clip Aesthetic!", open=False):
- with gr.Row():
- aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight",
- value=0.9)
- aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
-
- with gr.Row():
- aesthetic_lr = gr.Textbox(label='Aesthetic learning rate',
- placeholder="Aesthetic learning rate", value="0.0001")
- aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
- aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
- label="Aesthetic imgs embedding",
- value="None")
-
- modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")
-
- with gr.Row():
- aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
- placeholder="This text is used to rotate the feature space of the imgs embs",
- value="")
- aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01,
- value=0.1)
- aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
-
- return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
-
-
-aesthetic_clip_model = None
-
-
-def aesthetic_clip():
- global aesthetic_clip_model
-
- if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
- aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
- aesthetic_clip_model.cpu()
-
- return aesthetic_clip_model
-
-
-def generate_imgs_embd(name, folder, batch_size):
- model = aesthetic_clip().to(device)
- processor = CLIPProcessor.from_pretrained(model.name_or_path)
-
- with torch.no_grad():
- embs = []
- for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
- desc=f"Generating embeddings for {name}"):
- if shared.state.interrupted:
- break
- inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
- outputs = model.get_image_features(**inputs).cpu()
- embs.append(torch.clone(outputs))
- inputs.to("cpu")
- del inputs, outputs
-
- embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
-
- # The generated embedding will be located here
- path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
- torch.save(embs, path)
-
- model.cpu()
- del processor
- del embs
- gc.collect()
- torch.cuda.empty_cache()
- res = f"""
- Done generating embedding for {name}!
- Aesthetic embedding saved to {html.escape(path)}
- """
- shared.update_aesthetic_embeddings()
- return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
- value="None"), \
- gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
- label="Imgs embedding",
- value="None"), res, ""
-
-
-def slerp(low, high, val):
- low_norm = low / torch.norm(low, dim=1, keepdim=True)
- high_norm = high / torch.norm(high, dim=1, keepdim=True)
- omega = torch.acos((low_norm * high_norm).sum(1))
- so = torch.sin(omega)
- res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
- return res
-
-
-class AestheticCLIP:
- def __init__(self):
- self.skip = False
- self.aesthetic_steps = 0
- self.aesthetic_weight = 0
- self.aesthetic_lr = 0
- self.slerp = False
- self.aesthetic_text_negative = ""
- self.aesthetic_slerp_angle = 0
- self.aesthetic_imgs_text = ""
-
- self.image_embs_name = None
- self.image_embs = None
- self.load_image_embs(None)
-
- def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
- aesthetic_slerp=True, aesthetic_imgs_text="",
- aesthetic_slerp_angle=0.15,
- aesthetic_text_negative=False):
- self.aesthetic_imgs_text = aesthetic_imgs_text
- self.aesthetic_slerp_angle = aesthetic_slerp_angle
- self.aesthetic_text_negative = aesthetic_text_negative
- self.slerp = aesthetic_slerp
- self.aesthetic_lr = aesthetic_lr
- self.aesthetic_weight = aesthetic_weight
- self.aesthetic_steps = aesthetic_steps
- self.load_image_embs(image_embs_name)
-
- if self.image_embs_name is not None:
- p.extra_generation_params.update({
- "Aesthetic LR": aesthetic_lr,
- "Aesthetic weight": aesthetic_weight,
- "Aesthetic steps": aesthetic_steps,
- "Aesthetic embedding": self.image_embs_name,
- "Aesthetic slerp": aesthetic_slerp,
- "Aesthetic text": aesthetic_imgs_text,
- "Aesthetic text negative": aesthetic_text_negative,
- "Aesthetic slerp angle": aesthetic_slerp_angle,
- })
-
- def set_skip(self, skip):
- self.skip = skip
-
- def load_image_embs(self, image_embs_name):
- if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
- image_embs_name = None
- self.image_embs_name = None
- if image_embs_name is not None and self.image_embs_name != image_embs_name:
- self.image_embs_name = image_embs_name
- self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
- self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
- self.image_embs.requires_grad_(False)
-
- def __call__(self, z, remade_batch_tokens):
- if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
- tokenizer = shared.sd_model.cond_stage_model.tokenizer
- if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [
- [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
- remade_batch_tokens]
-
- tokens = torch.asarray(remade_batch_tokens).to(device)
-
- model = copy.deepcopy(aesthetic_clip()).to(device)
- model.requires_grad_(True)
- if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
- text_embs_2 = model.get_text_features(
- **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
- if self.aesthetic_text_negative:
- text_embs_2 = self.image_embs - text_embs_2
- text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
- img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
- else:
- img_embs = self.image_embs
-
- with torch.enable_grad():
-
- # We optimize the model to maximize the similarity
- optimizer = optim.Adam(
- model.text_model.parameters(), lr=self.aesthetic_lr
- )
-
- for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
- text_embs = model.get_text_features(input_ids=tokens)
- text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
- sim = text_embs @ img_embs.T
- loss = -sim
- optimizer.zero_grad()
- loss.mean().backward()
- optimizer.step()
-
- zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
- if opts.CLIP_stop_at_last_layers > 1:
- zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
- zn = model.text_model.final_layer_norm(zn)
- else:
- zn = zn.last_hidden_state
- model.cpu()
- del model
- gc.collect()
- torch.cuda.empty_cache()
- zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
- if self.slerp:
- z = slerp(z, zn, self.aesthetic_weight)
- else:
- z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
-
- return z
diff --git a/modules/images_history.py b/modules/images_history.py
index 78fd0543..bc5cf11f 100644
--- a/modules/images_history.py
+++ b/modules/images_history.py
@@ -310,7 +310,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
forward = gr.Button('Prev batch')
backward = gr.Button('Next batch')
with gr.Column(scale=3):
- load_info = gr.HTML(visible=not custom_dir)
+ load_info = gr.HTML(visible=not custom_dir)
with gr.Row(visible=False) as warning:
warning_box = gr.Textbox("Message", interactive=False)
diff --git a/modules/img2img.py b/modules/img2img.py
index eea5199b..8d9f7cf9 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_inpaint = mode == 1
is_batch = mode == 2
@@ -109,7 +109,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert,
)
- shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
+ p.scripts = modules.scripts.scripts_txt2img
+ p.script_args = args
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/processing.py b/modules/processing.py
index ff1ec4c9..372489f7 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -104,6 +104,12 @@ class StableDiffusionProcessing():
self.seed_resize_from_h = 0
self.seed_resize_from_w = 0
+ self.scripts = None
+ self.script_args = None
+ self.all_prompts = None
+ self.all_seeds = None
+ self.all_subseeds = None
+
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -350,32 +356,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
shared.prompt_styles.apply_styles(p)
if type(p.prompt) == list:
- all_prompts = p.prompt
+ p.all_prompts = p.prompt
else:
- all_prompts = p.batch_size * p.n_iter * [p.prompt]
+ p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
if type(seed) == list:
- all_seeds = seed
+ p.all_seeds = seed
else:
- all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
+ p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
if type(subseed) == list:
- all_subseeds = subseed
+ p.all_subseeds = subseed
else:
- all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
+ p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
+ if p.scripts is not None:
+ p.scripts.run_alwayson_scripts(p)
+
infotexts = []
output_images = []
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
- p.init(all_prompts, all_seeds, all_subseeds)
+ p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
if state.job_count == -1:
state.job_count = p.n_iter
@@ -387,9 +396,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break
- prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
- subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+ subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if (len(prompts) == 0):
break
@@ -490,10 +499,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+ return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
new file mode 100644
index 00000000..866b7acd
--- /dev/null
+++ b/modules/script_callbacks.py
@@ -0,0 +1,42 @@
+
+callbacks_model_loaded = []
+callbacks_ui_tabs = []
+
+
+def clear_callbacks():
+ callbacks_model_loaded.clear()
+ callbacks_ui_tabs.clear()
+
+
+def model_loaded_callback(sd_model):
+ for callback in callbacks_model_loaded:
+ callback(sd_model)
+
+
+def ui_tabs_callback():
+ res = []
+
+ for callback in callbacks_ui_tabs:
+ res += callback() or []
+
+ return res
+
+
+def on_model_loaded(callback):
+ """register a function to be called when the stable diffusion model is created; the model is
+ passed as an argument"""
+ callbacks_model_loaded.append(callback)
+
+
+def on_ui_tabs(callback):
+ """register a function to be called when the UI is creating new tabs.
+ The function must either return a None, which means no new tabs to be added, or a list, where
+ each element is a tuple:
+ (gradio_component, title, elem_id)
+
+ gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
+ title is tab text displayed to user in the UI
+ elem_id is HTML id for the tab
+ """
+ callbacks_ui_tabs.append(callback)
+
diff --git a/modules/scripts.py b/modules/scripts.py
index 1039fa9c..65f25f49 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,86 +1,153 @@
import os
import sys
import traceback
+from collections import namedtuple
import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared
+from modules import shared, paths, script_callbacks
+
+AlwaysVisible = object()
+
class Script:
filename = None
args_from = None
args_to = None
+ alwayson = False
+
+ infotext_fields = None
+ """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
+ parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
+ """
- # The title of the script. This is what will be displayed in the dropdown menu.
def title(self):
+ """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
+
raise NotImplementedError()
- # How the script is displayed in the UI. See https://gradio.app/docs/#components
- # for the different UI components you can use and how to create them.
- # Most UI components can return a value, such as a boolean for a checkbox.
- # The returned values are passed to the run method as parameters.
def ui(self, is_img2img):
+ """this function should create gradio UI elements. See https://gradio.app/docs/#components
+ The return value should be an array of all components that are used in processing.
+ Values of those returned componenbts will be passed to run() and process() functions.
+ """
+
pass
- # Determines when the script should be shown in the dropdown menu via the
- # returned value. As an example:
- # is_img2img is True if the current tab is img2img, and False if it is txt2img.
- # Thus, return is_img2img to only show the script on the img2img tab.
def show(self, is_img2img):
+ """
+ is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
+
+ This function should return:
+ - False if the script should not be shown in UI at all
+ - True if the script should be shown in UI if it's scelected in the scripts drowpdown
+ - script.AlwaysVisible if the script should be shown in UI at all times
+ """
+
return True
- # This is where the additional processing is implemented. The parameters include
- # self, the model object "p" (a StableDiffusionProcessing class, see
- # processing.py), and the parameters returned by the ui method.
- # Custom functions can be defined here, and additional libraries can be imported
- # to be used in processing. The return value should be a Processed object, which is
- # what is returned by the process_images method.
- def run(self, *args):
+ def run(self, p, *args):
+ """
+ This function is called if the script has been selected in the script dropdown.
+ It must do all processing and return the Processed object with results, same as
+ one returned by processing.process_images.
+
+ Usually the processing is done by calling the processing.process_images function.
+
+ args contains all values returned by components from ui()
+ """
+
raise NotImplementedError()
- # The description method is currently unused.
- # To add a description that appears when hovering over the title, amend the "titles"
- # dict in script.js to include the script title (returned by title) as a key, and
- # your description as the value.
+ def process(self, p, *args):
+ """
+ This function is called before processing begins for AlwaysVisible scripts.
+ scripts. You can modify the processing object (p) here, inject hooks, etc.
+ """
+
+ pass
+
def describe(self):
+ """unused"""
return ""
+current_basedir = paths.script_path
+
+
+def basedir():
+ """returns the base directory for the current script. For scripts in the main scripts directory,
+ this is the main directory (where webui.py resides), and for scripts in extensions directory
+ (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
+ """
+ return current_basedir
+
+
scripts_data = []
+ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
+ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
+
+
+def list_scripts(scriptdirname, extension):
+ scripts_list = []
+
+ basedir = os.path.join(paths.script_path, scriptdirname)
+ if os.path.exists(basedir):
+ for filename in sorted(os.listdir(basedir)):
+ scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
+
+ extdir = os.path.join(paths.script_path, "extensions")
+ if os.path.exists(extdir):
+ for dirname in sorted(os.listdir(extdir)):
+ dirpath = os.path.join(extdir, dirname)
+ if not os.path.isdir(dirpath):
+ continue
+ for filename in sorted(os.listdir(os.path.join(dirpath, scriptdirname))):
+ scripts_list.append(ScriptFile(dirpath, filename, os.path.join(dirpath, scriptdirname, filename)))
-def load_scripts(basedir):
- if not os.path.exists(basedir):
- return
+ scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
- for filename in sorted(os.listdir(basedir)):
- path = os.path.join(basedir, filename)
+ return scripts_list
- if os.path.splitext(path)[1].lower() != '.py':
- continue
- if not os.path.isfile(path):
- continue
+def load_scripts():
+ global current_basedir
+ scripts_data.clear()
+ script_callbacks.clear_callbacks()
+
+ scripts_list = list_scripts("scripts", ".py")
+
+ syspath = sys.path
+ for scriptfile in sorted(scripts_list):
try:
- with open(path, "r", encoding="utf8") as file:
+ if scriptfile.basedir != paths.script_path:
+ sys.path = [scriptfile.basedir] + sys.path
+ current_basedir = scriptfile.basedir
+
+ with open(scriptfile.path, "r", encoding="utf8") as file:
text = file.read()
from types import ModuleType
- compiled = compile(text, path, 'exec')
- module = ModuleType(filename)
+ compiled = compile(text, scriptfile.path, 'exec')
+ module = ModuleType(scriptfile.filename)
exec(compiled, module.__dict__)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
- scripts_data.append((script_class, path))
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
except Exception:
- print(f"Error loading script: {filename}", file=sys.stderr)
+ print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ finally:
+ sys.path = syspath
+ current_basedir = paths.script_path
+
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
@@ -96,56 +163,80 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
class ScriptRunner:
def __init__(self):
self.scripts = []
+ self.selectable_scripts = []
+ self.alwayson_scripts = []
self.titles = []
+ self.infotext_fields = []
def setup_ui(self, is_img2img):
- for script_class, path in scripts_data:
+ for script_class, path, basedir in scripts_data:
script = script_class()
script.filename = path
- if not script.show(is_img2img):
- continue
+ visibility = script.show(is_img2img)
- self.scripts.append(script)
+ if visibility == AlwaysVisible:
+ self.scripts.append(script)
+ self.alwayson_scripts.append(script)
+ script.alwayson = True
- self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
+ elif visibility:
+ self.scripts.append(script)
+ self.selectable_scripts.append(script)
- dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
- dropdown.save_to_config = True
- inputs = [dropdown]
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
+
+ inputs = [None]
+ inputs_alwayson = [True]
- for script in self.scripts:
+ def create_script_ui(script, inputs, inputs_alwayson):
script.args_from = len(inputs)
script.args_to = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
if controls is None:
- continue
+ return
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
- control.visible = False
+ if not script.alwayson:
+ control.visible = False
+
+ if script.infotext_fields is not None:
+ self.infotext_fields += script.infotext_fields
inputs += controls
+ inputs_alwayson += [script.alwayson for _ in controls]
script.args_to = len(inputs)
+ for script in self.alwayson_scripts:
+ with gr.Group():
+ create_script_ui(script, inputs, inputs_alwayson)
+
+ dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
+ dropdown.save_to_config = True
+ inputs[0] = dropdown
+
+ for script in self.selectable_scripts:
+ create_script_ui(script, inputs, inputs_alwayson)
+
def select_script(script_index):
- if 0 < script_index <= len(self.scripts):
- script = self.scripts[script_index-1]
+ if 0 < script_index <= len(self.selectable_scripts):
+ script = self.selectable_scripts[script_index-1]
args_from = script.args_from
args_to = script.args_to
else:
args_from = 0
args_to = 0
- return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
+ return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
def init_field(title):
if title == 'None':
return
script_index = self.titles.index(title)
- script = self.scripts[script_index]
+ script = self.selectable_scripts[script_index]
for i in range(script.args_from, script.args_to):
inputs[i].visible = True
@@ -164,7 +255,7 @@ class ScriptRunner:
if script_index == 0:
return None
- script = self.scripts[script_index-1]
+ script = self.selectable_scripts[script_index-1]
if script is None:
return None
@@ -176,6 +267,15 @@ class ScriptRunner:
return processed
+ def run_alwayson_scripts(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.process(p, *script_args)
+ except Exception:
+ print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def reload_sources(self):
for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
@@ -197,19 +297,21 @@ class ScriptRunner:
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
+
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
def reload_script_body_only():
scripts_txt2img.reload_sources()
scripts_img2img.reload_sources()
-def reload_scripts(basedir):
+def reload_scripts():
global scripts_txt2img, scripts_img2img
- scripts_data.clear()
- load_scripts(basedir)
+ load_scripts()
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 1f8587d1..0f10828e 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -332,7 +332,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
- z1 = shared.aesthetic_clip(z1, remade_batch_tokens)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
diff --git a/modules/sd_models.py b/modules/sd_models.py
index d99dbce8..f9b3063d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -7,7 +7,7 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices
+from modules import shared, modelloader, devices, script_callbacks
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
@@ -238,6 +238,9 @@ def load_model(checkpoint_info=None):
sd_hijack.model_hijack.hijack(sd_model)
sd_model.eval()
+ shared.sd_model = sd_model
+
+ script_callbacks.model_loaded_callback(sd_model)
print(f"Model loaded.")
return sd_model
@@ -252,7 +255,7 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
- shared.sd_model = load_model(checkpoint_info)
+ load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
diff --git a/modules/shared.py b/modules/shared.py
index 0dbe360d..7d786f07 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -31,7 +31,6 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
-parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(models_path, 'aesthetic_embeddings'), help="aesthetic_embeddings directory(default: aesthetic_embeddings)")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
@@ -109,21 +108,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
-
-os.makedirs(cmd_opts.aesthetic_embeddings_dir, exist_ok=True)
-aesthetic_embeddings = {}
-
-
-def update_aesthetic_embeddings():
- global aesthetic_embeddings
- aesthetic_embeddings = {f.replace(".pt", ""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
- os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
- aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings)
-
-
-update_aesthetic_embeddings()
-
-
def reload_hypernetworks():
global hypernetworks
@@ -415,9 +399,6 @@ sd_model = None
clip_model = None
-from modules.aesthetic_clip import AestheticCLIP
-aesthetic_clip = AestheticCLIP()
-
progress_print_out = sys.stdout
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 1761cfa2..c9d5a090 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -7,7 +7,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -36,7 +36,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
firstphase_height=firstphase_height if enable_hr else None,
)
- shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
+ p.scripts = modules.scripts.scripts_txt2img
+ p.script_args = args
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/ui.py b/modules/ui.py
index 70a9cf10..c977482c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -23,10 +23,10 @@ import gradio as gr
import gradio.utils
import gradio.routes
-from modules import sd_hijack, sd_models, localization
+from modules import sd_hijack, sd_models, localization, script_callbacks
from modules.paths import script_path
-from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings
+from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
@@ -44,7 +44,6 @@ from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui
-import modules.aesthetic_clip as aesthetic_clip
import modules.images_history as img_his
@@ -662,8 +661,6 @@ def create_ui(wrap_gradio_gpu_call):
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
- aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui()
-
with gr.Group():
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
@@ -718,14 +715,6 @@ def create_ui(wrap_gradio_gpu_call):
denoising_strength,
firstphase_width,
firstphase_height,
- aesthetic_lr,
- aesthetic_weight,
- aesthetic_steps,
- aesthetic_imgs,
- aesthetic_slerp,
- aesthetic_imgs_text,
- aesthetic_slerp_angle,
- aesthetic_text_negative
] + custom_inputs,
outputs=[
@@ -804,14 +793,7 @@ def create_ui(wrap_gradio_gpu_call):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
- (aesthetic_lr, "Aesthetic LR"),
- (aesthetic_weight, "Aesthetic weight"),
- (aesthetic_steps, "Aesthetic steps"),
- (aesthetic_imgs, "Aesthetic embedding"),
- (aesthetic_slerp, "Aesthetic slerp"),
- (aesthetic_imgs_text, "Aesthetic text"),
- (aesthetic_text_negative, "Aesthetic text negative"),
- (aesthetic_slerp_angle, "Aesthetic slerp angle"),
+ *modules.scripts.scripts_txt2img.infotext_fields
]
txt2img_preview_params = [
@@ -896,8 +878,6 @@ def create_ui(wrap_gradio_gpu_call):
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
- aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui()
-
with gr.Group():
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
@@ -988,14 +968,6 @@ def create_ui(wrap_gradio_gpu_call):
inpainting_mask_invert,
img2img_batch_input_dir,
img2img_batch_output_dir,
- aesthetic_lr_im,
- aesthetic_weight_im,
- aesthetic_steps_im,
- aesthetic_imgs_im,
- aesthetic_slerp_im,
- aesthetic_imgs_text_im,
- aesthetic_slerp_angle_im,
- aesthetic_text_negative_im,
] + custom_inputs,
outputs=[
img2img_gallery,
@@ -1087,14 +1059,7 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
- (aesthetic_lr_im, "Aesthetic LR"),
- (aesthetic_weight_im, "Aesthetic weight"),
- (aesthetic_steps_im, "Aesthetic steps"),
- (aesthetic_imgs_im, "Aesthetic embedding"),
- (aesthetic_slerp_im, "Aesthetic slerp"),
- (aesthetic_imgs_text_im, "Aesthetic text"),
- (aesthetic_text_negative_im, "Aesthetic text negative"),
- (aesthetic_slerp_angle_im, "Aesthetic slerp angle"),
+ *modules.scripts.scripts_img2img.infotext_fields
]
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
@@ -1217,9 +1182,9 @@ def create_ui(wrap_gradio_gpu_call):
)
#images history
images_history_switch_dict = {
- "fn":modules.generation_parameters_copypaste.connect_paste,
- "t2i":txt2img_paste_fields,
- "i2i":img2img_paste_fields
+ "fn": modules.generation_parameters_copypaste.connect_paste,
+ "t2i": txt2img_paste_fields,
+ "i2i": img2img_paste_fields
}
images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
@@ -1264,18 +1229,6 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary')
- with gr.Tab(label="Create aesthetic images embedding"):
-
- new_embedding_name_ae = gr.Textbox(label="Name")
- process_src_ae = gr.Textbox(label='Source directory')
- batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
- with gr.Row():
- with gr.Column(scale=3):
- gr.HTML(value="")
-
- with gr.Column():
- create_embedding_ae = gr.Button(value="Create images embedding", variant='primary')
-
with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
@@ -1375,21 +1328,6 @@ def create_ui(wrap_gradio_gpu_call):
]
)
- create_embedding_ae.click(
- fn=aesthetic_clip.generate_imgs_embd,
- inputs=[
- new_embedding_name_ae,
- process_src_ae,
- batch_ae
- ],
- outputs=[
- aesthetic_imgs,
- aesthetic_imgs_im,
- ti_output,
- ti_outcome,
- ]
- )
-
create_hypernetwork.click(
fn=modules.hypernetworks.ui.create_hypernetwork,
inputs=[
@@ -1580,10 +1518,10 @@ Requested path was: {f}
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
+ oldval = opts.data.get(key, None)
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
return gr.update(value=oldval), opts.dumpjson()
- oldval = opts.data.get(key, None)
opts.data[key] = value
if oldval != value:
@@ -1692,9 +1630,12 @@ Requested path was: {f}
(images_history, "Image Browser", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
- (settings_interface, "Settings", "settings"),
]
+ interfaces += script_callbacks.ui_tabs_callback()
+
+ interfaces += [(settings_interface, "Settings", "settings")]
+
with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
css = file.read()
--
cgit v1.2.3
From 6398dc9b1049f242576ca309f95a3fb1e654951c Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 13:34:49 +0300
Subject: further support for extensions
---
modules/scripts.py | 44 +++++++++++++++++++++++++++++++++++---------
modules/ui.py | 19 ++++++++++---------
2 files changed, 45 insertions(+), 18 deletions(-)
(limited to 'modules')
diff --git a/modules/scripts.py b/modules/scripts.py
index 65f25f49..9323af3e 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -102,17 +102,39 @@ def list_scripts(scriptdirname, extension):
if os.path.exists(extdir):
for dirname in sorted(os.listdir(extdir)):
dirpath = os.path.join(extdir, dirname)
- if not os.path.isdir(dirpath):
+ scriptdirpath = os.path.join(dirpath, scriptdirname)
+
+ if not os.path.isdir(scriptdirpath):
continue
- for filename in sorted(os.listdir(os.path.join(dirpath, scriptdirname))):
- scripts_list.append(ScriptFile(dirpath, filename, os.path.join(dirpath, scriptdirname, filename)))
+ for filename in sorted(os.listdir(scriptdirpath)):
+ scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
return scripts_list
+def list_files_with_name(filename):
+ res = []
+
+ dirs = [paths.script_path]
+
+ extdir = os.path.join(paths.script_path, "extensions")
+ if os.path.exists(extdir):
+ dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
+
+ for dirpath in dirs:
+ if not os.path.isdir(dirpath):
+ continue
+
+ path = os.path.join(dirpath, filename)
+ if os.path.isfile(filename):
+ res.append(path)
+
+ return res
+
+
def load_scripts():
global current_basedir
scripts_data.clear()
@@ -276,7 +298,7 @@ class ScriptRunner:
print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- def reload_sources(self):
+ def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from
@@ -286,9 +308,12 @@ class ScriptRunner:
from types import ModuleType
- compiled = compile(text, filename, 'exec')
- module = ModuleType(script.filename)
- exec(compiled, module.__dict__)
+ module = cache.get(filename, None)
+ if module is None:
+ compiled = compile(text, filename, 'exec')
+ module = ModuleType(script.filename)
+ exec(compiled, module.__dict__)
+ cache[filename] = module
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
@@ -303,8 +328,9 @@ scripts_img2img = ScriptRunner()
def reload_script_body_only():
- scripts_txt2img.reload_sources()
- scripts_img2img.reload_sources()
+ cache = {}
+ scripts_txt2img.reload_sources(cache)
+ scripts_img2img.reload_sources(cache)
def reload_scripts():
diff --git a/modules/ui.py b/modules/ui.py
index c977482c..29986124 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1636,13 +1636,15 @@ Requested path was: {f}
interfaces += [(settings_interface, "Settings", "settings")]
- with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
- css = file.read()
+ css = ""
+
+ for cssfile in modules.scripts.list_files_with_name("style.css"):
+ with open(cssfile, "r", encoding="utf8") as file:
+ css += file.read() + "\n"
if os.path.exists(os.path.join(script_path, "user.css")):
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
- usercss = file.read()
- css += usercss
+ css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
@@ -1865,9 +1867,9 @@ def load_javascript(raw_response):
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
javascript = f''
- jsdir = os.path.join(script_path, "javascript")
- for filename in sorted(os.listdir(jsdir)):
- with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
+ scripts_list = modules.scripts.list_scripts("javascript", ".js")
+ for basedir, filename, path in scripts_list:
+ with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n"
if cmd_opts.theme is not None:
@@ -1885,6 +1887,5 @@ def load_javascript(raw_response):
gradio.routes.templates.TemplateResponse = template_response
-reload_javascript = partial(load_javascript,
- gradio.routes.templates.TemplateResponse)
+reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
reload_javascript()
--
cgit v1.2.3
From 50b5504401e50b6c94eba41b37fe212b2f27b792 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 14:04:14 +0300
Subject: remove parsing command line from devices.py
---
modules/devices.py | 14 +++++---------
modules/lowvram.py | 9 ++++-----
2 files changed, 9 insertions(+), 14 deletions(-)
(limited to 'modules')
diff --git a/modules/devices.py b/modules/devices.py
index 8a159282..dc1f3cdd 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -15,14 +15,10 @@ def extract_device_id(args, name):
def get_optimal_device():
if torch.cuda.is_available():
- # CUDA device selection support:
- if "shared" not in sys.modules:
- commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop.
- sys.argv += shlex.split(commandline_args)
- device_id = extract_device_id(sys.argv, '--device-id')
- else:
- device_id = shared.cmd_opts.device_id
-
+ from modules import shared
+
+ device_id = shared.cmd_opts.device_id
+
if device_id is not None:
cuda_device = f"cuda:{device_id}"
return torch.device(cuda_device)
@@ -49,7 +45,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 7eba1349..f327c3df 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,9 +1,8 @@
import torch
-from modules.devices import get_optimal_device
+from modules import devices
module_in_gpu = None
cpu = torch.device("cpu")
-device = gpu = get_optimal_device()
def send_everything_to_cpu():
@@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram):
if module_in_gpu is not None:
module_in_gpu.to(cpu)
- module.to(gpu)
+ module.to(devices.device)
module_in_gpu = module
# see below for register_forward_pre_hook;
@@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
- sd_model.to(device)
+ sd_model.to(devices.device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
# register hooks for those the first two models
@@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
- sd_model.model.to(device)
+ sd_model.model.to(devices.device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
--
cgit v1.2.3
From 0e8ca8e7af05be22d7d2c07a47c3c7febe0f0ab6 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Sat, 22 Oct 2022 11:07:00 +0000
Subject: add dropout
---
modules/hypernetworks/hypernetwork.py | 68 +++++++++++++++++++++--------------
modules/hypernetworks/ui.py | 10 +++---
modules/ui.py | 43 +++++++++++-----------
3 files changed, 70 insertions(+), 51 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 905cbeef..e493f366 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,47 +1,60 @@
+import csv
import datetime
import glob
import html
import os
import sys
import traceback
-import tqdm
-import csv
+import modules.textual_inversion.dataset
import torch
-
-from ldm.util import default
-from modules import devices, shared, processing, sd_models
-import torch
-from torch import einsum
+import tqdm
from einops import rearrange, repeat
-import modules.textual_inversion.dataset
+from ldm.util import default
+from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from torch import einsum
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
- activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU,
- "swish": torch.nn.Hardswish}
-
- def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
+ activation_dict = {
+ "relu": torch.nn.ReLU,
+ "leakyrelu": torch.nn.LeakyReLU,
+ "elu": torch.nn.ELU,
+ "swish": torch.nn.Hardswish,
+ }
+
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
-
+ assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
+
linears = []
for i in range(len(layer_structure) - 1):
+
+ # Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
- # if skip_first_layer because first parameters potentially contain negative values
- # if i < 1: continue
- if activation_func in HypernetworkModule.activation_dict:
- linears.append(HypernetworkModule.activation_dict[activation_func]())
+
+ # Add an activation func
+ if activation_func == "linear":
+ pass
+ elif activation_func in self.activation_dict:
+ linears.append(self.activation_dict[activation_func]())
else:
- print("Invalid key {} encountered as activation function!".format(activation_func))
- # if use_dropout:
- # linears.append(torch.nn.Dropout(p=0.3))
+ raise NotImplementedError(
+ "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
+ )
+
+ # Add dropout
+ if use_dropout:
+ linears.append(torch.nn.Dropout(p=0.3))
+
+ # Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -93,7 +106,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
self.filename = None
self.name = name
self.layers = {}
@@ -101,13 +114,14 @@ class Hypernetwork:
self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
- self.add_layer_norm = add_layer_norm
self.activation_func = activation_func
+ self.add_layer_norm = add_layer_norm
+ self.use_dropout = use_dropout
for size in enable_sizes or []:
self.layers[size] = (
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
- HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
)
def weights(self):
@@ -129,8 +143,9 @@ class Hypernetwork:
state_dict['step'] = self.step
state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure
- state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['activation_func'] = self.activation_func
+ state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -144,8 +159,9 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
- self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.activation_func = state_dict.get('activation_func', None)
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ self.use_dropout = state_dict.get('use_dropout', False)
for size, sd in state_dict.items():
if type(size) == int:
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 1a5a27d8..5f6f17b6 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -3,14 +3,13 @@ import os
import re
import gradio as gr
-
-import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared, devices
+import modules.textual_inversion.textual_inversion
+from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
+def create_hypernetwork(name, enable_sizes, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
@@ -21,8 +20,9 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
name=name,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
- add_layer_norm=add_layer_norm,
activation_func=activation_func,
+ add_layer_norm=add_layer_norm,
+ use_dropout=use_dropout,
)
hypernet.save(fn)
diff --git a/modules/ui.py b/modules/ui.py
index 716f14b8..d4b32c05 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -5,43 +5,44 @@ import json
import math
import mimetypes
import os
+import platform
import random
+import subprocess as sp
import sys
import tempfile
import time
import traceback
-import platform
-import subprocess as sp
from functools import partial, reduce
+import gradio as gr
+import gradio.routes
+import gradio.utils
import numpy as np
+import piexif
import torch
from PIL import Image, PngImagePlugin
-import piexif
-import gradio as gr
-import gradio.utils
-import gradio.routes
-
-from modules import sd_hijack, sd_models, localization
+from modules import localization, sd_hijack, sd_models
from modules.paths import script_path
-from modules.shared import opts, cmd_opts, restricted_opts
+from modules.shared import cmd_opts, opts, restricted_opts
+
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
-import modules.shared as shared
-from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.sd_hijack import model_hijack
+
+import modules.codeformer_model
+import modules.generation_parameters_copypaste
+import modules.gfpgan_model
+import modules.hypernetworks.ui
+import modules.images_history as img_his
import modules.ldsr_model
import modules.scripts
-import modules.gfpgan_model
-import modules.codeformer_model
+import modules.shared as shared
import modules.styles
-import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
from modules import prompt_parser
from modules.images import save_image
-import modules.textual_inversion.ui
-import modules.hypernetworks.ui
-import modules.images_history as img_his
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -1223,8 +1224,9 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
with gr.Row():
with gr.Column(scale=3):
@@ -1308,8 +1310,9 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name,
new_hypernetwork_sizes,
new_hypernetwork_layer_structure,
- new_hypernetwork_add_layer_norm,
new_hypernetwork_activation_func,
+ new_hypernetwork_add_layer_norm,
+ new_hypernetwork_use_dropout
],
outputs=[
train_hypernetwork_name,
--
cgit v1.2.3
From 1cd3ed7def40198f46d30f74dd37d2906ebdbaa6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 14:28:56 +0300
Subject: fix for extensions without style.css
---
modules/ui.py | 3 +++
1 file changed, 3 insertions(+)
(limited to 'modules')
diff --git a/modules/ui.py b/modules/ui.py
index 29986124..d8d52db1 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1639,6 +1639,9 @@ Requested path was: {f}
css = ""
for cssfile in modules.scripts.list_files_with_name("style.css"):
+ if not os.path.isfile(cssfile):
+ continue
+
with open(cssfile, "r", encoding="utf8") as file:
css += file.read() + "\n"
--
cgit v1.2.3
From 7fd90128eb6d1820045bfe2c2c1269661023a712 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 14:48:43 +0300
Subject: added a guard for hypernet training that will stop early if weights
are getting no gradients
---
modules/hypernetworks/hypernetwork.py | 11 +++++++++++
1 file changed, 11 insertions(+)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 47d91ea5..46039a49 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -310,6 +310,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+ steps_without_grad = 0
+
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
@@ -332,8 +334,17 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item()
optimizer.zero_grad()
+ weights[0].grad = None
loss.backward()
+
+ if weights[0].grad is None:
+ steps_without_grad += 1
+ else:
+ steps_without_grad = 0
+ assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
+
optimizer.step()
+
mean_loss = losses.mean()
if torch.isnan(mean_loss):
raise RuntimeError("Loss diverged.")
--
cgit v1.2.3
From fccba4729db341a299db3343e3264fecd9459a07 Mon Sep 17 00:00:00 2001
From: discus0434
Date: Sat, 22 Oct 2022 12:02:41 +0000
Subject: add an option to avoid dying relu
---
modules/hypernetworks/hypernetwork.py | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b7a04038..3132a56c 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -32,7 +32,6 @@ class HypernetworkModule(torch.nn.Module):
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
- assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
linears = []
for i in range(len(layer_structure) - 1):
@@ -43,12 +42,13 @@ class HypernetworkModule(torch.nn.Module):
# Add an activation func
if activation_func == "linear" or activation_func is None:
pass
+ # If ReLU, Skip adding it to the first layer to avoid dying ReLU
+ elif activation_func == "relu" and i < 1:
+ pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
- raise RuntimeError(
- "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
- )
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
# Add dropout
if use_dropout:
@@ -166,8 +166,8 @@ class Hypernetwork:
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
- HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
)
self.name = state_dict.get('name', self.name)
--
cgit v1.2.3
From 7912acef725832debef58c4c7bf8ec22fb446c0b Mon Sep 17 00:00:00 2001
From: discus0434
Date: Sat, 22 Oct 2022 13:00:44 +0000
Subject: small fix
---
modules/hypernetworks/hypernetwork.py | 12 +++++-------
modules/ui.py | 1 -
2 files changed, 5 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3132a56c..7d12e0ff 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -42,22 +42,20 @@ class HypernetworkModule(torch.nn.Module):
# Add an activation func
if activation_func == "linear" or activation_func is None:
pass
- # If ReLU, Skip adding it to the first layer to avoid dying ReLU
- elif activation_func == "relu" and i < 1:
- pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
- # Add dropout
- if use_dropout:
- linears.append(torch.nn.Dropout(p=0.3))
-
# Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+ # Add dropout
+ if use_dropout:
+ p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2
+ linears.append(torch.nn.Dropout(p=p))
+
self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
diff --git a/modules/ui.py b/modules/ui.py
index cd118552..eca887ca 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1244,7 +1244,6 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
- new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row():
with gr.Column(scale=3):
--
cgit v1.2.3
From 6a4fa73a38935a18779ce1809892730fd1572bee Mon Sep 17 00:00:00 2001
From: discus0434
Date: Sat, 22 Oct 2022 13:44:39 +0000
Subject: small fix
---
modules/hypernetworks/hypernetwork.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3372aae2..3bc71ee5 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -51,10 +51,9 @@ class HypernetworkModule(torch.nn.Module):
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
- # Add dropout
- if use_dropout:
- p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2
- linears.append(torch.nn.Dropout(p=p))
+ # Add dropout expect last layer
+ if use_dropout and i < len(layer_structure) - 3:
+ linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
--
cgit v1.2.3
From d37cfffd537cd29309afbcb192c4f979995c6a34 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 19:18:56 +0300
Subject: added callback for creating new settings in extensions
---
modules/script_callbacks.py | 11 +++++++++++
modules/shared.py | 19 +++++++++++++++++--
modules/ui.py | 6 +++++-
3 files changed, 33 insertions(+), 3 deletions(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 866b7acd..1270e50f 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,6 +1,7 @@
callbacks_model_loaded = []
callbacks_ui_tabs = []
+callbacks_ui_settings = []
def clear_callbacks():
@@ -22,6 +23,11 @@ def ui_tabs_callback():
return res
+def ui_settings_callback():
+ for callback in callbacks_ui_settings:
+ callback()
+
+
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
@@ -40,3 +46,8 @@ def on_ui_tabs(callback):
"""
callbacks_ui_tabs.append(callback)
+
+def on_ui_settings(callback):
+ """register a function to be called before UI settingsare populated; add your settings
+ by using shared.opts.add_option(shared.OptionInfo(...)) """
+ callbacks_ui_settings.append(callback)
diff --git a/modules/shared.py b/modules/shared.py
index 5d83971e..d9cb65ef 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -165,13 +165,13 @@ def realesrgan_models_names():
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
- self.section = None
+ self.section = section
self.refresh = refresh
@@ -327,6 +327,7 @@ options_templates.update(options_section(('images-history', "Images Browser"), {
}))
+
class Options:
data = None
data_labels = options_templates
@@ -389,6 +390,20 @@ class Options:
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d)
+ def add_option(self, key, info):
+ self.data_labels[key] = info
+
+ def reorder(self):
+ """reorder settings so that all items related to section always go together"""
+
+ section_ids = {}
+ settings_items = self.data_labels.items()
+ for k, item in settings_items:
+ if item.section not in section_ids:
+ section_ids[item.section] = len(section_ids)
+
+ self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+
opts = Options()
if os.path.exists(config_filename):
diff --git a/modules/ui.py b/modules/ui.py
index d8d52db1..2849b111 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1461,6 +1461,9 @@ def create_ui(wrap_gradio_gpu_call):
components = []
component_dict = {}
+ script_callbacks.ui_settings_callback()
+ opts.reorder()
+
def open_folder(f):
if not os.path.exists(f):
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
@@ -1564,7 +1567,8 @@ Requested path was: {f}
previous_section = item.section
- gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value=''.format(item.section[1]))
+ elem_id, text = item.section
+ gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value=''.format(text))
if k in quicksettings_names:
quicksettings_list.append((i, k, item))
--
cgit v1.2.3
From dbc8ab65f6d496459a76547776b656c96ad1350d Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 19:19:17 +0300
Subject: typo
---
modules/script_callbacks.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 1270e50f..5bcccd67 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -48,6 +48,6 @@ def on_ui_tabs(callback):
def on_ui_settings(callback):
- """register a function to be called before UI settingsare populated; add your settings
+ """register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
callbacks_ui_settings.append(callback)
--
cgit v1.2.3
From 72383abacdc6a101704a6f73758ce4d0bb68c9d1 Mon Sep 17 00:00:00 2001
From: Greendayle
Date: Sat, 22 Oct 2022 16:50:07 +0200
Subject: Deepdanbooru linux fix
---
modules/deepbooru.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 8914662d..3c34ab7c 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -50,7 +50,8 @@ def create_deepbooru_process(threshold, deepbooru_opts):
the tags.
"""
from modules import shared # prevents circular reference
- shared.deepbooru_process_manager = multiprocessing.Manager()
+ context = multiprocessing.get_context("spawn")
+ shared.deepbooru_process_manager = context.Manager()
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1
--
cgit v1.2.3
From e38625011cd4955da4bc67fe95d1d0f4c0c53899 Mon Sep 17 00:00:00 2001
From: Greendayle
Date: Sat, 22 Oct 2022 16:56:52 +0200
Subject: fix part2
---
modules/deepbooru.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 3c34ab7c..8bbc90a4 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -55,7 +55,7 @@ def create_deepbooru_process(threshold, deepbooru_opts):
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1
- shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
+ shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
shared.deepbooru_process.start()
--
cgit v1.2.3
From 324c7c732dd9afc3d4c397c354797ae5d655b514 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 20:09:37 +0300
Subject: record First pass size as 0x0 for #3328
---
modules/processing.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 372489f7..27c669b0 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -524,6 +524,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
+ self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
+
if self.firstphase_width == 0 or self.firstphase_height == 0:
desired_pixel_count = 512 * 512
actual_pixel_count = self.width * self.height
@@ -545,7 +547,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
firstphase_width_truncated = self.firstphase_height * self.width / self.height
firstphase_height_truncated = self.firstphase_height
- self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
--
cgit v1.2.3
From 0df94d3fcf9d1fc47c4d39039352a3d5b3380c1f Mon Sep 17 00:00:00 2001
From: MrCheeze
Date: Sat, 22 Oct 2022 12:59:21 -0400
Subject: fix aesthetic gradients doing nothing after loading a different model
---
modules/sd_models.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f9b3063d..49dc3238 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -236,12 +236,11 @@ def load_model(checkpoint_info=None):
sd_model.to(shared.device)
sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
sd_model.eval()
shared.sd_model = sd_model
- script_callbacks.model_loaded_callback(sd_model)
-
print(f"Model loaded.")
return sd_model
@@ -268,6 +267,7 @@ def reload_model_weights(sd_model, info=None):
load_model_weights(sd_model, checkpoint_info)
sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
--
cgit v1.2.3
From 321bacc6a9eaf4a25f31279f288fa752be507a20 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 20:15:12 +0300
Subject: call model_loaded_callback after setting shared.sd_model in case
scripts refer to it using that
---
modules/sd_models.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 49dc3238..e697bb72 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -236,11 +236,12 @@ def load_model(checkpoint_info=None):
sd_model.to(shared.device)
sd_hijack.model_hijack.hijack(sd_model)
- script_callbacks.model_loaded_callback(sd_model)
sd_model.eval()
shared.sd_model = sd_model
+ script_callbacks.model_loaded_callback(sd_model)
+
print(f"Model loaded.")
return sd_model
--
cgit v1.2.3
From 24694e5983d0944b901892cb101878e6dec89a20 Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sun, 23 Oct 2022 01:57:58 +0900
Subject: Update hypernetwork.py
---
modules/hypernetworks/hypernetwork.py | 55 ++++++++++++++++++++++++++++-------
1 file changed, 44 insertions(+), 11 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 3bc71ee5..81132be4 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
+from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
@@ -268,6 +269,32 @@ def stack_conds(conds):
return torch.stack(conds)
+def log_statistics(loss_info:dict, key, value):
+ if key not in loss_info:
+ loss_info[key] = [value]
+ else:
+ loss_info[key].append(value)
+ if len(loss_info) > 1024:
+ loss_info.pop(0)
+
+
+def statistics(data):
+ total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})"
+ recent_data = data[-32:]
+ recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})"
+ return total_information, recent_information
+
+
+def report_statistics(loss_info:dict):
+ keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
+ for key in keys:
+ info, recent = statistics(loss_info[key])
+ print("Loss statistics for file " + key)
+ print(info)
+ print(recent)
+
+
+
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@@ -310,7 +337,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
for weight in weights:
weight.requires_grad = True
- losses = torch.zeros((32,))
+ size = len(ds.indexes)
+ loss_dict = {}
+ losses = torch.zeros((size,))
+ previous_mean_loss = 0
+ print("Mean loss of {} elements".format(size))
last_saved_file = ""
last_saved_image = ""
@@ -329,7 +360,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
-
+ if loss_dict and i % size == 0:
+ previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
+
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
@@ -346,7 +379,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
del c
losses[hypernetwork.step % losses.shape[0]] = loss.item()
-
+ for entry in entries:
+ log_statistics(loss_dict, entry.filename, loss.item())
+
optimizer.zero_grad()
weights[0].grad = None
loss.backward()
@@ -359,10 +394,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
optimizer.step()
- mean_loss = losses.mean()
- if torch.isnan(mean_loss):
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
- pbar.set_description(f"loss: {mean_loss:.7f}")
+ pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
@@ -371,7 +405,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
- "loss": f"{mean_loss:.7f}",
+ "loss": f"{previous_mean_loss:.7f}",
"learn_rate": scheduler.learn_rate
})
@@ -420,14 +454,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f"""
-Loss: {mean_loss:.7f}
+Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
Last prompt: {html.escape(entries[0].cond_text)}
Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
"""
-
+
+ report_statistics(loss_dict)
checkpoint = sd_models.select_checkpoint()
hypernetwork.sd_checkpoint = checkpoint.hash
@@ -438,5 +473,3 @@ Last saved image: {html.escape(last_saved_image)}
hypernetwork.save(filename)
return hypernetwork, filename
-
-
--
cgit v1.2.3
From 4fdb53c1e9962507fc8336dad9a0fabfe6c418c0 Mon Sep 17 00:00:00 2001
From: Unnoen
Date: Wed, 19 Oct 2022 21:38:10 +1100
Subject: Generate grid preview for progress image
---
modules/sd_samplers.py | 26 +++++++++++++++++++++++++-
modules/shared.py | 1 +
modules/ui.py | 5 ++++-
3 files changed, 30 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index f58a29b9..74a480e5 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,7 +7,7 @@ import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, processing
+from modules import prompt_parser, devices, processing, images
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -89,6 +89,30 @@ def sample_to_image(samples):
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
+def samples_to_image_grid(samples):
+ progress_images = []
+ for i in range(len(samples)):
+ # Decode the samples individually to reduce VRAM usage at the cost of a bit of speed.
+ x_sample = processing.decode_first_stage(shared.sd_model, samples[i:i+1])[0]
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ progress_images.append(Image.fromarray(x_sample))
+
+ return images.image_grid(progress_images)
+
+def samples_to_image_grid_combined(samples):
+ progress_images = []
+ # Decode all samples at once to increase speed at the cost of VRAM usage.
+ x_samples = processing.decode_first_stage(shared.sd_model, samples)
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for x_sample in x_samples:
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ progress_images.append(Image.fromarray(x_sample))
+
+ return images.image_grid(progress_images)
def store_latent(decoded):
state.current_latent = decoded
diff --git a/modules/shared.py b/modules/shared.py
index d9cb65ef..95d6e225 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -294,6 +294,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+ "progress_decode_combined": OptionInfo(False, "Decode all progress images at once. (Slighty speeds up progress generation but consumes significantly more VRAM with large batches.)"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
diff --git a/modules/ui.py b/modules/ui.py
index 56c233ab..de0abc7e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -318,7 +318,10 @@ def check_progress_call(id_part):
if shared.parallel_processing_allowed:
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
- shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
+ if opts.progress_decode_combined:
+ shared.state.current_image = modules.sd_samplers.samples_to_image_grid_combined(shared.state.current_latent)
+ else:
+ shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
image = shared.state.current_image
--
cgit v1.2.3
From d213d6ca6f90094cb45c11e2f3cb37d25a8d1f94 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 20:48:13 +0300
Subject: removed the option to use 2x more memory when generating previews
added an option to always only show one image in previews removed duplicate
code
---
modules/sd_samplers.py | 35 ++++++++++-------------------------
modules/shared.py | 2 +-
modules/ui.py | 6 +++---
3 files changed, 14 insertions(+), 29 deletions(-)
(limited to 'modules')
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 74a480e5..0b408a70 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -71,6 +71,7 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
+
def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None:
steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
@@ -82,37 +83,21 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
-def sample_to_image(samples):
- x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
+def single_sample_to_image(sample):
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
+
+def sample_to_image(samples):
+ return single_sample_to_image(samples[0])
+
+
def samples_to_image_grid(samples):
- progress_images = []
- for i in range(len(samples)):
- # Decode the samples individually to reduce VRAM usage at the cost of a bit of speed.
- x_sample = processing.decode_first_stage(shared.sd_model, samples[i:i+1])[0]
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- progress_images.append(Image.fromarray(x_sample))
-
- return images.image_grid(progress_images)
-
-def samples_to_image_grid_combined(samples):
- progress_images = []
- # Decode all samples at once to increase speed at the cost of VRAM usage.
- x_samples = processing.decode_first_stage(shared.sd_model, samples)
- x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
- for x_sample in x_samples:
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- progress_images.append(Image.fromarray(x_sample))
-
- return images.image_grid(progress_images)
+ return images.image_grid([single_sample_to_image(sample) for sample in samples])
+
def store_latent(decoded):
state.current_latent = decoded
diff --git a/modules/shared.py b/modules/shared.py
index 95d6e225..25bfc895 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -294,7 +294,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
- "progress_decode_combined": OptionInfo(False, "Decode all progress images at once. (Slighty speeds up progress generation but consumes significantly more VRAM with large batches.)"),
+ "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
diff --git a/modules/ui.py b/modules/ui.py
index de0abc7e..ffa14cac 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -318,10 +318,10 @@ def check_progress_call(id_part):
if shared.parallel_processing_allowed:
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
- if opts.progress_decode_combined:
- shared.state.current_image = modules.sd_samplers.samples_to_image_grid_combined(shared.state.current_latent)
- else:
+ if opts.show_progress_grid:
shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
+ else:
+ shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
image = shared.state.current_image
--
cgit v1.2.3
From be748e8b086bd9834d08bdd9160649a5e7700af7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 22:05:22 +0300
Subject: add --freeze-settings commandline argument to disable changing
settings
---
modules/shared.py | 1 +
modules/ui.py | 11 +++++++++--
2 files changed, 10 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/shared.py b/modules/shared.py
index 25bfc895..b55371d3 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -64,6 +64,7 @@ parser.add_argument("--port", type=int, help="launch gradio with given server po
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
+parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
diff --git a/modules/ui.py b/modules/ui.py
index ffa14cac..2311572c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -580,6 +580,9 @@ def apply_setting(key, value):
if value is None:
return gr.update()
+ if shared.cmd_opts.freeze_settings:
+ return gr.update()
+
# dont allow model to be swapped when model hash exists in prompt
if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
return gr.update()
@@ -1501,6 +1504,8 @@ Requested path was: {f}
def run_settings(*args):
changed = 0
+ assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
+
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
@@ -1530,6 +1535,8 @@ Requested path was: {f}
return f'{changed} settings changed.', opts.dumpjson()
def run_settings_single(value, key):
+ assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
+
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
@@ -1582,7 +1589,7 @@ Requested path was: {f}
elem_id, text = item.section
gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value=''.format(text))
- if k in quicksettings_names:
+ if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
else:
@@ -1615,7 +1622,7 @@ Requested path was: {f}
def reload_scripts():
modules.scripts.reload_script_body_only()
- reload_javascript() # need to refresh the html page
+ reload_javascript() # need to refresh the html page
reload_script_bodies.click(
fn=reload_scripts,
--
cgit v1.2.3
From ca5a9e79dc28eeaa3a161427a82e34703bf15765 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 22 Oct 2022 22:06:54 +0300
Subject: fix for img2img color correction in a batch #3218
---
modules/processing.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/processing.py b/modules/processing.py
index 27c669b0..b1877b80 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -403,8 +403,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if (len(prompts) == 0):
break
- #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
- #c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
@@ -716,6 +714,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size
+
+ if self.color_corrections is not None and len(self.color_corrections) == 1:
+ self.color_corrections = self.color_corrections * self.batch_size
+
elif len(imgs) <= self.batch_size:
self.batch_size = len(imgs)
batch_images = np.array(imgs)
--
cgit v1.2.3
From 48dbf99e84045ee7af55bc5b1b86492a240e631e Mon Sep 17 00:00:00 2001
From: AngelBottomless <35677394+aria1th@users.noreply.github.com>
Date: Sun, 23 Oct 2022 04:17:16 +0900
Subject: Allow tracking real-time loss
Someone had 6000 images in their dataset, and it was shown as 0, which was confusing.
This will allow tracking real time dataset-average loss for registered objects.
---
modules/hypernetworks/hypernetwork.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 81132be4..99fd0f8f 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -360,7 +360,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
- if loss_dict and i % size == 0:
+ if len(loss_dict) > 0:
previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
scheduler.apply(optimizer, hypernetwork.step)
--
cgit v1.2.3
From 1fbfc052eb529d8cf8ce5baf578bcf93d0280c29 Mon Sep 17 00:00:00 2001
From: DepFA <35278260+dfaker@users.noreply.github.com>
Date: Sun, 23 Oct 2022 05:43:34 +0100
Subject: Update hypernetwork.py
---
modules/hypernetworks/hypernetwork.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
(limited to 'modules')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 99fd0f8f..98a7b62e 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -288,10 +288,13 @@ def statistics(data):
def report_statistics(loss_info:dict):
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
for key in keys:
- info, recent = statistics(loss_info[key])
- print("Loss statistics for file " + key)
- print(info)
- print(recent)
+ try:
+ print("Loss statistics for file " + key)
+ info, recent = statistics(loss_info[key])
+ print(info)
+ print(recent)
+ except Exception as e:
+ print(e)
--
cgit v1.2.3
From a7c213d0f5ebb10722629b8490a5863f9ce6c4fa Mon Sep 17 00:00:00 2001
From: Stephen
Date: Fri, 21 Oct 2022 19:27:40 -0400
Subject: [API][Feature] - Add img2img API endpoint
---
modules/api/api.py | 58 +++++++++++++++++++++++++++++++++++++++++++----
modules/api/processing.py | 11 +++++++--
modules/processing.py | 2 +-
3 files changed, 63 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 5b0c934e..a04f2428 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,5 +1,5 @@
-from modules.api.processing import StableDiffusionProcessingAPI
-from modules.processing import StableDiffusionProcessingTxt2Img, process_images
+from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, Json
import json
import io
import base64
+from PIL import Image
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
@@ -18,6 +19,11 @@ class TextToImageResponse(BaseModel):
parameters: Json
info: Json
+class ImageToImageResponse(BaseModel):
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: Json
+ info: Json
+
class Api:
def __init__(self, app, queue_lock):
@@ -25,8 +31,9 @@ class Api:
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
- def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
+ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
@@ -54,8 +61,49 @@ class Api:
- def img2imgapi(self):
- raise NotImplementedError
+ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+ sampler_index = sampler_to_index(img2imgreq.sampler_index)
+
+ if sampler_index is None:
+ raise HTTPException(status_code=404, detail="Sampler not found")
+
+
+ init_images = img2imgreq.init_images
+ if init_images is None:
+ raise HTTPException(status_code=404, detail="Init image not found")
+
+
+ populate = img2imgreq.copy(update={ # Override __init__ params
+ "sd_model": shared.sd_model,
+ "sampler_index": sampler_index[0],
+ "do_not_save_samples": True,
+ "do_not_save_grid": True
+ }
+ )
+ p = StableDiffusionProcessingImg2Img(**vars(populate))
+
+ imgs = []
+ for img in init_images:
+ # if has a comma, deal with prefix
+ if "," in img:
+ img = img.split(",")[1]
+ # convert base64 to PIL image
+ img = base64.b64decode(img)
+ img = Image.open(io.BytesIO(img))
+ imgs = [img] * p.batch_size
+
+ p.init_images = imgs
+ # Override object param
+ with self.queue_lock:
+ processed = process_images(p)
+
+ b64images = []
+ for i in processed.images:
+ buffer = io.BytesIO()
+ i.save(buffer, format="png")
+ b64images.append(base64.b64encode(buffer.getvalue()))
+
+ return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))
def extrasapi(self):
raise NotImplementedError
diff --git a/modules/api/processing.py b/modules/api/processing.py
index 4c541241..9f1d65c0 100644
--- a/modules/api/processing.py
+++ b/modules/api/processing.py
@@ -1,7 +1,8 @@
+from array import array
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
-from modules.processing import StableDiffusionProcessingTxt2Img
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
import inspect
@@ -92,8 +93,14 @@ class PydanticModelGenerator:
DynamicModel.__config__.allow_mutation = True
return DynamicModel
-StableDiffusionProcessingAPI = PydanticModelGenerator(
+StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}]
+).generate_model()
+
+StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingImg2Img",
+ StableDiffusionProcessingImg2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}]
).generate_model()
\ No newline at end of file
diff --git a/modules/processing.py b/modules/processing.py
index b1877b80..1557ed8c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -623,7 +623,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
+ def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: str=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
--
cgit v1.2.3
From 9e1a8b7734a2881451a2efbf80def011ea41ba49 Mon Sep 17 00:00:00 2001
From: Stephen
Date: Sat, 22 Oct 2022 15:42:00 -0400
Subject: non-implemented mask with any type
---
modules/api/api.py | 4 ++++
modules/api/processing.py | 2 +-
modules/processing.py | 2 +-
3 files changed, 6 insertions(+), 2 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index a04f2428..3df6ff96 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -72,6 +72,10 @@ class Api:
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
+ mask = img2imgreq.mask
+ if mask:
+ raise HTTPException(status_code=400, detail="Mask not supported yet")
+
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
diff --git a/modules/api/processing.py b/modules/api/processing.py
index 9f1d65c0..f551fa35 100644
--- a/modules/api/processing.py
+++ b/modules/api/processing.py
@@ -102,5 +102,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
).generate_model()
\ No newline at end of file
diff --git a/modules/processing.py b/modules/processing.py
index 1557ed8c..ff83023c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -623,7 +623,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: str=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
+ def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
--
cgit v1.2.3
From 5dc0739ecdc1ade8fcf4eb77f2a503ef12489f32 Mon Sep 17 00:00:00 2001
From: Stephen
Date: Sat, 22 Oct 2022 17:10:28 -0400
Subject: working mask
---
modules/api/api.py | 20 ++++++++++++--------
1 file changed, 12 insertions(+), 8 deletions(-)
(limited to 'modules')
diff --git a/modules/api/api.py b/modules/api/api.py
index 3df6ff96..3caa83a4 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -33,6 +33,14 @@ class Api:
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
+ def __base64_to_image(self, base64_string):
+ # if has a comma, deal with prefix
+ if "," in base64_string:
+ base64_string = base64_string.split(",")[1]
+ imgdata = base64.b64decode(base64_string)
+ # convert base64 to PIL image
+ return Image.open(io.BytesIO(imgdata))
+
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -74,26 +82,22 @@ class Api:
mask = img2imgreq.mask
if mask:
- raise HTTPException(status_code=400, detail="Mask not supported yet")
+ mask = self.__base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
- "do_not_save_grid": True
+ "do_not_save_grid": True,
+ "mask": mask
}
)
p = StableDiffusionProcessingImg2Img(**vars(populate))
imgs = []
for img in init_images:
- # if has a comma, deal with prefix
- if "," in img:
- img = img.split(",")[1]
- # convert base64 to PIL image
- img = base64.b64decode(img)
- img = Image.open(io.BytesIO(img))
+ img = self.__base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs
--
cgit v1.2.3