aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-07 09:32:28 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-09-07 09:32:28 +0000
commit6a9b33c848281cb02f38764e4f91ef767f5e3edd (patch)
treed3af251b9e38b0187f2cfe6ad4d81bb5e3ae3eeb
parent9cb3cc3a2f5f419dd594f3322fa35113a6ed2391 (diff)
downloadstable-diffusion-webui-gfx803-6a9b33c848281cb02f38764e4f91ef767f5e3edd.tar.gz
stable-diffusion-webui-gfx803-6a9b33c848281cb02f38764e4f91ef767f5e3edd.tar.bz2
stable-diffusion-webui-gfx803-6a9b33c848281cb02f38764e4f91ef767f5e3edd.zip
codeformer support
-rw-r--r--modules/codeformer/codeformer_arch.py276
-rw-r--r--modules/codeformer/vqgan_arch.py435
-rw-r--r--modules/codeformer_model.py108
-rw-r--r--modules/face_restoration.py19
-rw-r--r--modules/gfpgan_model.py20
-rw-r--r--modules/img2img.py4
-rw-r--r--modules/paths.py13
-rw-r--r--modules/processing.py12
-rw-r--r--modules/shared.py6
-rw-r--r--modules/txt2img.py4
-rw-r--r--modules/ui.py15
-rw-r--r--webui.bat19
-rw-r--r--webui.py20
13 files changed, 919 insertions, 32 deletions
diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py
new file mode 100644
index 00000000..0eff93dc
--- /dev/null
+++ b/modules/codeformer/codeformer_arch.py
@@ -0,0 +1,276 @@
+import math
+import numpy as np
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from typing import Optional, List
+
+from modules.codeformer.vqgan_arch import *
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+ return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+ """Adaptive instance normalization.
+
+ Adjust the reference features to have the similar color and illuminations
+ as those in the degradate features.
+
+ Args:
+ content_feat (Tensor): The reference feature.
+ style_feat (Tensor): The degradate features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
+
+
+class TransformerSALayer(nn.Module):
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
+ # Implementation of Feedforward model - MLP
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
+
+ self.norm1 = nn.LayerNorm(embed_dim)
+ self.norm2 = nn.LayerNorm(embed_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward(self, tgt,
+ tgt_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+
+ # self attention
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+
+ # ffn
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
+
+class Fuse_sft_block(nn.Module):
+ def __init__(self, in_ch, out_ch):
+ super().__init__()
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
+
+ self.scale = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ self.shift = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ def forward(self, enc_feat, dec_feat, w=1):
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
+ scale = self.scale(enc_feat)
+ shift = self.shift(enc_feat)
+ residual = w * (dec_feat * scale + shift)
+ out = dec_feat + residual
+ return out
+
+
+@ARCH_REGISTRY.register()
+class CodeFormer(VQAutoEncoder):
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
+ codebook_size=1024, latent_size=256,
+ connect_list=['32', '64', '128', '256'],
+ fix_modules=['quantize','generator']):
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
+
+ if fix_modules is not None:
+ for module in fix_modules:
+ for param in getattr(self, module).parameters():
+ param.requires_grad = False
+
+ self.connect_list = connect_list
+ self.n_layers = n_layers
+ self.dim_embd = dim_embd
+ self.dim_mlp = dim_embd*2
+
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
+ self.feat_emb = nn.Linear(256, self.dim_embd)
+
+ # transformer
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
+ for _ in range(self.n_layers)])
+
+ # logits_predict head
+ self.idx_pred_layer = nn.Sequential(
+ nn.LayerNorm(dim_embd),
+ nn.Linear(dim_embd, codebook_size, bias=False))
+
+ self.channels = {
+ '16': 512,
+ '32': 256,
+ '64': 256,
+ '128': 128,
+ '256': 128,
+ '512': 64,
+ }
+
+ # after second residual block for > 16, before attn layer for ==16
+ self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
+ # after first residual block for > 16, before attn layer for ==16
+ self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
+
+ # fuse_convs_dict
+ self.fuse_convs_dict = nn.ModuleDict()
+ for f_size in self.connect_list:
+ in_ch = self.channels[f_size]
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
+ # ################### Encoder #####################
+ enc_feat_dict = {}
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
+ for i, block in enumerate(self.encoder.blocks):
+ x = block(x)
+ if i in out_list:
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
+
+ lq_feat = x
+ # ################# Transformer ###################
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
+ # BCHW -> BC(HW) -> (HW)BC
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
+ query_emb = feat_emb
+ # Transformer encoder
+ for layer in self.ft_layers:
+ query_emb = layer(query_emb, query_pos=pos_emb)
+
+ # output logits
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
+
+ if code_only: # for training stage II
+ # logits doesn't need softmax before cross_entropy loss
+ return logits, lq_feat
+
+ # ################# Quantization ###################
+ # if self.training:
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
+ # # b(hw)c -> bc(hw) -> bchw
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
+ # ------------
+ soft_one_hot = F.softmax(logits, dim=2)
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
+ # preserve gradients
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
+
+ if detach_16:
+ quant_feat = quant_feat.detach() # for training stage III
+ if adain:
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
+
+ # ################## Generator ####################
+ x = quant_feat
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
+
+ for i, block in enumerate(self.generator.blocks):
+ x = block(x)
+ if i in fuse_list: # fuse after i-th block
+ f_size = str(x.shape[-1])
+ if w>0:
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
+ out = x
+ # logits doesn't need softmax before cross_entropy loss
+ return out, logits, lq_feat \ No newline at end of file
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
new file mode 100644
index 00000000..f6dfcf4c
--- /dev/null
+++ b/modules/codeformer/vqgan_arch.py
@@ -0,0 +1,435 @@
+'''
+VQGAN code, adapted from the original created by the Unleashing Transformers authors:
+https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
+
+'''
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+def normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+@torch.jit.script
+def swish(x):
+ return x*torch.sigmoid(x)
+
+
+# Define VQVAE classes
+class VectorQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.emb_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
+
+ mean_distance = torch.mean(d)
+ # find closest encodings
+ # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+ min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
+ # [0-1], higher score, higher confidence
+ min_encoding_scores = torch.exp(-min_encoding_scores/10)
+
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, {
+ "perplexity": perplexity,
+ "min_encodings": min_encodings,
+ "min_encoding_indices": min_encoding_indices,
+ "min_encoding_scores": min_encoding_scores,
+ "mean_distance": mean_distance
+ }
+
+ def get_codebook_feat(self, indices, shape):
+ # input indices: batch*token_num -> (batch*token_num)*1
+ # shape: batch, height, width, channel
+ indices = indices.view(-1,1)
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
+ min_encodings.scatter_(1, indices, 1)
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None: # reshape back to match original input shape
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
+ super().__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
+ self.embed = nn.Embedding(codebook_size, emb_dim)
+
+ def forward(self, z):
+ hard = self.straight_through if self.training else True
+
+ logits = self.proj(z)
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
+
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
+
+ return z_q, diff, {
+ "min_encoding_indices": min_encoding_indices
+ }
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+
+ return x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels, out_channels=None):
+ super(ResBlock, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.norm1 = normalize(in_channels)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = normalize(out_channels)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x_in):
+ x = x_in
+ x = self.norm1(x)
+ x = swish(x)
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = swish(x)
+ x = self.conv2(x)
+ if self.in_channels != self.out_channels:
+ x_in = self.conv_out(x_in)
+
+ return x + x_in
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h*w)
+ q = q.permute(0, 2, 1)
+ k = k.reshape(b, c, h*w)
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = F.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h*w)
+ w_ = w_.permute(0, 2, 1)
+ h_ = torch.bmm(v, w_)
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.attn_resolutions = attn_resolutions
+
+ curr_res = self.resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+
+ blocks = []
+ # initial convultion
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
+
+ # residual and downsampling blocks, with attention on smaller res (16x16)
+ for i in range(self.num_resolutions):
+ block_in_ch = nf * in_ch_mult[i]
+ block_out_ch = nf * ch_mult[i]
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+ if curr_res in attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != self.num_resolutions - 1:
+ blocks.append(Downsample(block_in_ch))
+ curr_res = curr_res // 2
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ # normalise and convert to latent size
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+class Generator(nn.Module):
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.ch_mult = ch_mult
+ self.num_resolutions = len(self.ch_mult)
+ self.num_res_blocks = res_blocks
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.in_channels = emb_dim
+ self.out_channels = 3
+ block_in_ch = self.nf * self.ch_mult[-1]
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
+
+ blocks = []
+ # initial conv
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ for i in reversed(range(self.num_resolutions)):
+ block_out_ch = self.nf * self.ch_mult[i]
+
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+
+ if curr_res in self.attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != 0:
+ blocks.append(Upsample(block_in_ch))
+ curr_res = curr_res * 2
+
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
+
+ self.blocks = nn.ModuleList(blocks)
+
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+@ARCH_REGISTRY.register()
+class VQAutoEncoder(nn.Module):
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
+ super().__init__()
+ logger = get_root_logger()
+ self.in_channels = 3
+ self.nf = nf
+ self.n_blocks = res_blocks
+ self.codebook_size = codebook_size
+ self.embed_dim = emb_dim
+ self.ch_mult = ch_mult
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.quantizer_type = quantizer
+ self.encoder = Encoder(
+ self.in_channels,
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+ if self.quantizer_type == "nearest":
+ self.beta = beta #0.25
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
+ elif self.quantizer_type == "gumbel":
+ self.gumbel_num_hiddens = emb_dim
+ self.straight_through = gumbel_straight_through
+ self.kl_weight = gumbel_kl_weight
+ self.quantize = GumbelQuantizer(
+ self.codebook_size,
+ self.embed_dim,
+ self.gumbel_num_hiddens,
+ self.straight_through,
+ self.kl_weight
+ )
+ self.generator = Generator(
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+
+ if model_path is not None:
+ chkpt = torch.load(model_path, map_location='cpu')
+ if 'params_ema' in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
+ logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
+ elif 'params' in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
+ logger.info(f'vqgan is loaded from: {model_path} [params]')
+ else:
+ raise ValueError(f'Wrong params!')
+
+
+ def forward(self, x):
+ x = self.encoder(x)
+ quant, codebook_loss, quant_stats = self.quantize(x)
+ x = self.generator(quant)
+ return x, codebook_loss, quant_stats
+
+
+
+# patch based discriminator
+@ARCH_REGISTRY.register()
+class VQGANDiscriminator(nn.Module):
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
+ super().__init__()
+
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
+ ndf_mult = 1
+ ndf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n, 8)
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n_layers, 8)
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
+ self.main = nn.Sequential(*layers)
+
+ if model_path is not None:
+ chkpt = torch.load(model_path, map_location='cpu')
+ if 'params_d' in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
+ elif 'params' in chkpt:
+ self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
+ else:
+ raise ValueError(f'Wrong params!')
+
+ def forward(self, x):
+ return self.main(x) \ No newline at end of file
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
new file mode 100644
index 00000000..836417c2
--- /dev/null
+++ b/modules/codeformer_model.py
@@ -0,0 +1,108 @@
+import os
+import sys
+import traceback
+import torch
+
+from modules import shared
+from modules.paths import script_path
+import modules.shared
+import modules.face_restoration
+from importlib import reload
+
+# codeformer people made a choice to include modified basicsr librry to their projectwhich makes
+# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN.
+# I am making a choice to include some files from codeformer to work around this issue.
+
+pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+
+have_codeformer = False
+
+def setup_codeformer():
+ path = modules.paths.paths.get("CodeFormer", None)
+ if path is None:
+ return
+
+
+ # both GFPGAN and CodeFormer use bascisr, one has it installed from pip the other uses its own
+ #stored_sys_path = sys.path
+ #sys.path = [path] + sys.path
+
+ try:
+ from torchvision.transforms.functional import normalize
+ from modules.codeformer.codeformer_arch import CodeFormer
+ from basicsr.utils.download_util import load_file_from_url
+ from basicsr.utils import imwrite, img2tensor, tensor2img
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
+ from modules.shared import cmd_opts
+
+ net_class = CodeFormer
+
+ class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
+ def name(self):
+ return "CodeFormer"
+
+ def __init__(self):
+ self.net = None
+ self.face_helper = None
+
+ def create_models(self):
+
+ if self.net is not None and self.face_helper is not None:
+ return self.net, self.face_helper
+
+ net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(shared.device)
+ ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
+ checkpoint = torch.load(ckpt_path)['params_ema']
+ net.load_state_dict(checkpoint)
+ net.eval()
+
+ face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=shared.device)
+
+ if not cmd_opts.unload_gfpgan:
+ self.net = net
+ self.face_helper = face_helper
+
+ return net, face_helper
+
+ def restore(self, np_image):
+ np_image = np_image[:, :, ::-1]
+
+ net, face_helper = self.create_models()
+ face_helper.clean_all()
+ face_helper.read_image(np_image)
+ face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+ face_helper.align_warp_face()
+
+ for idx, cropped_face in enumerate(face_helper.cropped_faces):
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(shared.device)
+
+ try:
+ with torch.no_grad():
+ output = net(cropped_face_t, w=shared.opts.code_former_weight, adain=True)[0]
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+ del output
+ torch.cuda.empty_cache()
+ except Exception as error:
+ print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+
+ restored_face = restored_face.astype('uint8')
+ face_helper.add_restored_face(restored_face)
+
+ face_helper.get_inverse_affine(None)
+
+ restored_img = face_helper.paste_faces_to_input_image()
+ restored_img = restored_img[:, :, ::-1]
+ return restored_img
+
+ global have_codeformer
+ have_codeformer = True
+ shared.face_restorers.append(FaceRestorerCodeFormer())
+
+ except Exception:
+ print("Error setting up CodeFormer:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ # sys.path = stored_sys_path
diff --git a/modules/face_restoration.py b/modules/face_restoration.py
new file mode 100644
index 00000000..4ae53d21
--- /dev/null
+++ b/modules/face_restoration.py
@@ -0,0 +1,19 @@
+from modules import shared
+
+
+class FaceRestoration:
+ def name(self):
+ return "None"
+
+ def restore(self, np_image):
+ return np_image
+
+
+def restore_faces(np_image):
+ face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
+ if len(face_restorers) == 0:
+ return np_image
+
+ face_restorer = face_restorers[0]
+
+ return face_restorer.restore(np_image)
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 334a1b7f..f697326c 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -2,12 +2,15 @@ import os
import sys
import traceback
-from modules.paths import script_path
+from modules import shared
from modules.shared import cmd_opts
-import modules.shared
+from modules.paths import script_path
+import modules.face_restoration
def gfpgan_model_path():
+ from modules.shared import cmd_opts
+
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places]
found = [x for x in files if os.path.exists(x)]
@@ -62,6 +65,19 @@ def setup_gfpgan():
global gfpgan_constructor
gfpgan_constructor = GFPGANer
+
+ class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
+ def name(self):
+ return "GFPGAN"
+
+ def restore(self, np_image):
+ np_image_bgr = np_image[:, :, ::-1]
+ cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+ np_image = gfpgan_output_bgr[:, :, ::-1]
+
+ return np_image
+
+ shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
print("Error setting up GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/img2img.py b/modules/img2img.py
index 4d8d4818..3129798d 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts
-def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
+def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
is_inpaint = mode == 1
is_loopback = mode == 2
is_upscale = mode == 3
@@ -36,7 +36,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
cfg_scale=cfg_scale,
width=width,
height=height,
- use_GFPGAN=use_GFPGAN,
+ restore_faces=restore_faces,
tiling=tiling,
init_images=[image],
mask=mask,
diff --git a/modules/paths.py b/modules/paths.py
index e1559bc7..6cf5e2b6 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -5,7 +5,7 @@ import sys
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.insert(0, script_path)
-# use current directory as SD dir if it has related files, otherwise parent dir of script as stated in guide
+# search for directory of stable diffsuion in following palces
sd_path = None
possible_sd_paths = ['.', os.path.dirname(script_path), os.path.join(script_path, 'repositories/stable-diffusion')]
for possible_sd_path in possible_sd_paths:
@@ -14,14 +14,19 @@ for possible_sd_path in possible_sd_paths:
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + possible_sd_paths
-# add parent directory to path; this is where Stable diffusion repo should be
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion'),
- (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers')
+ (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers'),
+ (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
]
+
+paths = {}
+
for d, must_exist, what in path_dirs:
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
if not os.path.exists(must_exist_path):
print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
else:
- sys.path.append(os.path.join(script_path, d))
+ d = os.path.abspath(d)
+ sys.path.append(d)
+ paths[what] = d
diff --git a/modules/processing.py b/modules/processing.py
index 78bc73b9..49474b73 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -14,7 +14,7 @@ from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
-import modules.gfpgan_model as gfpgan
+import modules.face_restoration
import modules.images as images
# some of those options should not be changed at all because they would break the model, so I removed them from options.
@@ -29,7 +29,7 @@ def torch_gc():
class StableDiffusionProcessing:
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, use_GFPGAN=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -44,7 +44,7 @@ class StableDiffusionProcessing:
self.cfg_scale: float = cfg_scale
self.width: int = width
self.height: int = height
- self.use_GFPGAN: bool = use_GFPGAN
+ self.restore_faces: bool = restore_faces
self.tiling: bool = tiling
self.do_not_save_samples: bool = do_not_save_samples
self.do_not_save_grid: bool = do_not_save_grid
@@ -136,7 +136,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
"Sampler": samplers[p.sampler_index].name,
"CFG scale": p.cfg_scale,
"Seed": all_seeds[position_in_batch + iteration * p.batch_size],
- "GFPGAN": ("GFPGAN" if p.use_GFPGAN else None),
+ "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
}
@@ -193,10 +193,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
- if p.use_GFPGAN:
+ if p.restore_faces:
torch_gc()
- x_sample = gfpgan.gfpgan_fix_faces(x_sample)
+ x_sample = modules.face_restoration.restore_faces(x_sample)
image = Image.fromarray(x_sample)
diff --git a/modules/shared.py b/modules/shared.py
index fb5bbd43..12082895 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -1,11 +1,13 @@
import argparse
import json
import os
+
import gradio as gr
import torch
import modules.artists
from modules.paths import script_path, sd_path
+import modules.codeformer_model
config_filename = "config.json"
@@ -40,6 +42,7 @@ device = gpu if torch.cuda.is_available() else cpu
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
+
class State:
interrupted = False
job = ""
@@ -65,6 +68,7 @@ state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
+face_restorers = []
def find_any_font():
fonts = ['/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf']
@@ -116,6 +120,8 @@ class Options:
"upscale_at_full_resolution_padding": OptionInfo(16, "Inpainting at full resolution: padding, in pixels, for the masked region.", gr.Slider, {"minimum": 0, "maximum": 128, "step": 4}),
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+ "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
+ "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
}
def __init__(self):
diff --git a/modules/txt2img.py b/modules/txt2img.py
index dfce49ff..fd81ff0f 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -6,7 +6,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args):
+def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -21,7 +21,7 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u
cfg_scale=cfg_scale,
width=width,
height=height,
- use_GFPGAN=use_GFPGAN,
+ restore_faces=restore_faces,
tiling=tiling,
)
diff --git a/modules/ui.py b/modules/ui.py
index 92d8bcdd..05161b00 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -206,7 +206,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
with gr.Row():
- use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
tiling = gr.Checkbox(label='Tiling', value=False)
with gr.Row():
@@ -253,7 +253,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
negative_prompt,
steps,
sampler_index,
- use_gfpgan,
+ restore_faces,
tiling,
batch_count,
batch_size,
@@ -335,7 +335,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
inpainting_mask_invert = gr.Radio(label='Masking mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", visible=False)
with gr.Row():
- use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
tiling = gr.Checkbox(label='Tiling', value=False)
sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
@@ -425,7 +425,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
sampler_index,
mask_blur,
inpainting_fill,
- use_gfpgan,
+ restore_faces,
tiling,
switch_mode,
batch_count,
@@ -521,7 +521,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
with gr.Group():
- gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=0, interactive=gfpgan.have_gfpgan)
+ face_restoration_blending = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Faces restoration visibility", value=0, interactive=len(shared.face_restorers) > 1)
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
@@ -534,7 +534,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
fn=run_extras,
inputs=[
image,
- gfpgan_strength,
+ face_restoration_blending,
upscaling_resize,
extras_upscaler_1,
extras_upscaler_2,
@@ -585,7 +585,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
t = type(info.default)
if info.component is not None:
- item = info.component(label=info.label, value=fun, **(info.component_args or {}))
+ args = info.component_args() if callable(info.component_args) else info.component_args
+ item = info.component(label=info.label, value=fun, **(args or {}))
elif t == str:
item = gr.Textbox(label=info.label, value=fun, lines=1)
elif t == int:
diff --git a/webui.bat b/webui.bat
index ed724362..055a19b0 100644
--- a/webui.bat
+++ b/webui.bat
@@ -92,6 +92,7 @@ echo Installing requirements...
%PYTHON% -m pip install -r %REQS_FILE% --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :update_numpy
goto :show_stdout_stderr
+
:update_numpy
%PYTHON% -m pip install -U numpy --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
@@ -105,12 +106,28 @@ if %ERRORLEVEL% == 0 goto :clone_transformers
goto :show_stdout_stderr
:clone_transformers
-if exist repositories\taming-transformers goto :check_model
+if exist repositories\taming-transformers goto :clone_codeformer
echo Cloning Taming Transforming repository...
%GIT% clone https://github.com/CompVis/taming-transformers.git repositories\taming-transformers >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :clone_codeformer
+goto :show_stdout_stderr
+
+:clone_codeformer
+if exist repositories\CodeFormer goto :install_codeformer_reqs
+echo Cloning CodeFormer repository...
+%GIT% clone https://github.com/sczhou/CodeFormer.git repositories\CodeFormer >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :install_codeformer_reqs
+goto :show_stdout_stderr
+
+:install_codeformer_reqs
+%PYTHON% -c "import lpips" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :check_model
+echo Installing requirements for CodeFormer...
+%PYTHON% -m pip install -r repositories\CodeFormer\requirements.txt --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :check_model
goto :show_stdout_stderr
+
:check_model
dir model.ckpt >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :check_gfpgan
diff --git a/webui.py b/webui.py
index d20ff38f..c9a800dd 100644
--- a/webui.py
+++ b/webui.py
@@ -19,7 +19,9 @@ from modules.ui import plaintext_to_html
import modules.scripts
import modules.processing as processing
import modules.sd_hijack
-import modules.gfpgan_model as gfpgan
+import modules.codeformer_model
+import modules.gfpgan_model
+import modules.face_restoration
import modules.realesrgan_model as realesrgan
import modules.esrgan_model as esrgan
import modules.images as images
@@ -28,10 +30,12 @@ import modules.txt2img
import modules.img2img
+modules.codeformer_model.setup_codeformer()
+modules.gfpgan_model.setup_gfpgan()
+shared.face_restorers.append(modules.face_restoration.FaceRestoration())
+
esrgan.load_models(cmd_opts.esrgan_models_path)
realesrgan.setup_realesrgan()
-gfpgan.setup_gfpgan()
-
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
@@ -54,19 +58,19 @@ def load_model_from_config(config, ckpt, verbose=False):
cached_images = {}
-def run_extras(image, gfpgan_strength, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
+def run_extras(image, face_restoration_blending, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
processing.torch_gc()
image = image.convert("RGB")
outpath = opts.outdir_samples or opts.outdir_extras_samples
- if gfpgan.have_gfpgan is not None and gfpgan_strength > 0:
- restored_img = gfpgan.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
+ if face_restoration_blending > 0:
+ restored_img = modules.face_restoration.restore_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)
- if gfpgan_strength < 1.0:
- res = Image.blend(image, res, gfpgan_strength)
+ if face_restoration_blending < 1.0:
+ res = Image.blend(image, res, face_restoration_blending)
image = res