author    | AUTOMATIC1111 <16777216c@gmail.com> | 2023-05-18 07:26:35 +0000
committer | GitHub <noreply@github.com> | 2023-05-18 07:26:35 +0000
commit    | 97e1cf69c04a3c62aa1bb19a14ffc948d9cc6c4e (patch)
tree      | 7a24bdd31580fe0e4bf8d4205b57b55df0a2568d /modules/codeformer/vqgan_arch.py
parent    | 484948f5c0b755a921c02cccbcacb2684a86a814 (diff)
parent    | bb431df52bf3dc5e233e42907f2d8f56e4fb6c0c (diff)
download  | stable-diffusion-webui-gfx803-97e1cf69c04a3c62aa1bb19a14ffc948d9cc6c4e.tar.gz
          | stable-diffusion-webui-gfx803-97e1cf69c04a3c62aa1bb19a14ffc948d9cc6c4e.tar.bz2
          | stable-diffusion-webui-gfx803-97e1cf69c04a3c62aa1bb19a14ffc948d9cc6c4e.zip
Merge branch 'dev' into master
Diffstat (limited to 'modules/codeformer/vqgan_arch.py')
-rw-r--r-- | modules/codeformer/vqgan_arch.py | 44
1 file changed, 21 insertions, 23 deletions
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
index e7293683..09ee6660 100644
--- a/modules/codeformer/vqgan_arch.py
+++ b/modules/codeformer/vqgan_arch.py
@@ -5,17 +5,15 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
 https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
 '''
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import copy
 from basicsr.utils import get_root_logger
 from basicsr.utils.registry import ARCH_REGISTRY
 
 
 def normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-    
+
 
 @torch.jit.script
 def swish(x):
@@ -212,15 +210,15 @@ class AttnBlock(nn.Module):
         # compute attention
         b, c, h, w = q.shape
         q = q.reshape(b, c, h*w)
-        q = q.permute(0, 2, 1) 
+        q = q.permute(0, 2, 1)
         k = k.reshape(b, c, h*w)
-        w_ = torch.bmm(q, k) 
+        w_ = torch.bmm(q, k)
         w_ = w_ * (int(c)**(-0.5))
         w_ = F.softmax(w_, dim=2)
 
         # attend to values
         v = v.reshape(b, c, h*w)
-        w_ = w_.permute(0, 2, 1) 
+        w_ = w_.permute(0, 2, 1)
         h_ = torch.bmm(v, w_)
         h_ = h_.reshape(b, c, h, w)
 
@@ -272,18 +270,18 @@ class Encoder(nn.Module):
     def forward(self, x):
         for block in self.blocks:
             x = block(x)
-            
+
         return x
 
 
 class Generator(nn.Module):
     def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
         super().__init__()
-        self.nf = nf 
-        self.ch_mult = ch_mult 
+        self.nf = nf
+        self.ch_mult = ch_mult
         self.num_resolutions = len(self.ch_mult)
         self.num_res_blocks = res_blocks
-        self.resolution = img_size 
+        self.resolution = img_size
         self.attn_resolutions = attn_resolutions
         self.in_channels = emb_dim
         self.out_channels = 3
@@ -317,29 +315,29 @@ class Generator(nn.Module):
         blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
 
         self.blocks = nn.ModuleList(blocks)
-   
+
 
     def forward(self, x):
         for block in self.blocks:
             x = block(x)
-            
+
         return x
 
-  
+
 @ARCH_REGISTRY.register()
 class VQAutoEncoder(nn.Module):
-    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
         super().__init__()
         logger = get_root_logger()
-        self.in_channels = 3 
-        self.nf = nf 
-        self.n_blocks = res_blocks 
+        self.in_channels = 3
+        self.nf = nf
+        self.n_blocks = res_blocks
         self.codebook_size = codebook_size
         self.embed_dim = emb_dim
         self.ch_mult = ch_mult
         self.resolution = img_size
-        self.attn_resolutions = attn_resolutions
+        self.attn_resolutions = attn_resolutions or [16]
         self.quantizer_type = quantizer
         self.encoder = Encoder(
             self.in_channels,
@@ -365,11 +363,11 @@ class VQAutoEncoder(nn.Module):
                 self.kl_weight
             )
         self.generator = Generator(
-            self.nf, 
+            self.nf,
             self.embed_dim,
-            self.ch_mult, 
-            self.n_blocks, 
-            self.resolution, 
+            self.ch_mult,
+            self.n_blocks,
+            self.resolution,
             self.attn_resolutions
         )
 
@@ -434,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
             raise ValueError('Wrong params!')
 
     def forward(self, x):
-        return self.main(x)
\ No newline at end of file
+        return self.main(x)
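The one behavioral change in this otherwise whitespace-and-imports cleanup is the `VQAutoEncoder` signature: the mutable default `attn_resolutions=[16]` becomes `attn_resolutions=None`, materialized inside `__init__` as `attn_resolutions or [16]`. A minimal standalone sketch of the pitfall this avoids (hypothetical function names, not from this repo):

```python
# Default arguments are evaluated once, at function definition time,
# so a default list is shared across every call that omits the argument.
def risky(values=[]):
    values.append(1)
    return values

print(risky())  # [1]
print(risky())  # [1, 1]  <- state leaks between calls

# The pattern the diff adopts: default to None, materialize per call.
def safe(values=None):
    values = values or []
    values.append(1)
    return values

print(safe())  # [1]
print(safe())  # [1]
```

One caveat with the `or` idiom: an explicitly passed empty list is also replaced by the default; `if values is None` is the stricter spelling when an empty value must be preserved. For `attn_resolutions` that distinction does not matter in practice.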
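For context on the `AttnBlock` hunk, the touched lines are plain scaled dot-product self-attention over the h*w spatial positions, written with `torch.bmm`. A quick equivalence check (an illustration only, not part of the commit; assumes PyTorch >= 2.0 for `F.scaled_dot_product_attention`):

```python
import torch
import torch.nn.functional as F

b, c, h, w = 2, 32, 8, 8
q, k, v = (torch.randn(b, c, h, w) for _ in range(3))

# bmm formulation, as in AttnBlock.forward
q_ = q.reshape(b, c, h*w).permute(0, 2, 1)   # (b, hw, c) queries
k_ = k.reshape(b, c, h*w)                    # (b, c, hw) keys
w_ = torch.bmm(q_, k_) * (int(c) ** (-0.5))  # (b, hw, hw) scaled logits
w_ = F.softmax(w_, dim=2)                    # normalize over key positions
v_ = v.reshape(b, c, h*w)
h_ = torch.bmm(v_, w_.permute(0, 2, 1)).reshape(b, c, h, w)

# same result from the fused kernel (PyTorch >= 2.0)
sdpa = F.scaled_dot_product_attention(
    q.reshape(b, c, h*w).transpose(1, 2),  # (b, hw, c)
    k.reshape(b, c, h*w).transpose(1, 2),  # (b, hw, c)
    v.reshape(b, c, h*w).transpose(1, 2),  # (b, hw, c)
).transpose(1, 2).reshape(b, c, h, w)

print(torch.allclose(h_, sdpa, atol=1e-5))  # True
```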