diff options
author | Roman Beltiukov <maybe.hello.world@gmail.com> | 2023-05-25 22:10:10 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-25 22:10:10 +0000 |
commit | b2530c965c2afd5512c5f9020251fd4be8f067e5 (patch) | |
tree | 0c1620e00ac4eddea514706a5c3bf3e03bd46c70 /modules/codeformer/vqgan_arch.py | |
parent | 09d9c3d287ee4543d285e0fde8b81603c9751a7e (diff) | |
parent | a6e653be26cc05f4438145fa0082816e9fbbf5fc (diff) | |
download | stable-diffusion-webui-gfx803-b2530c965c2afd5512c5f9020251fd4be8f067e5.tar.gz stable-diffusion-webui-gfx803-b2530c965c2afd5512c5f9020251fd4be8f067e5.tar.bz2 stable-diffusion-webui-gfx803-b2530c965c2afd5512c5f9020251fd4be8f067e5.zip |
Merge branch 'dev' into master
Diffstat (limited to 'modules/codeformer/vqgan_arch.py')
-rw-r--r-- | modules/codeformer/vqgan_arch.py | 44 |
1 files changed, 21 insertions, 23 deletions
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py index e7293683..09ee6660 100644 --- a/modules/codeformer/vqgan_arch.py +++ b/modules/codeformer/vqgan_arch.py @@ -5,17 +5,15 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py ''' -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import copy from basicsr.utils import get_root_logger from basicsr.utils.registry import ARCH_REGISTRY def normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - + @torch.jit.script def swish(x): @@ -212,15 +210,15 @@ class AttnBlock(nn.Module): # compute attention b, c, h, w = q.shape q = q.reshape(b, c, h*w) - q = q.permute(0, 2, 1) + q = q.permute(0, 2, 1) k = k.reshape(b, c, h*w) - w_ = torch.bmm(q, k) + w_ = torch.bmm(q, k) w_ = w_ * (int(c)**(-0.5)) w_ = F.softmax(w_, dim=2) # attend to values v = v.reshape(b, c, h*w) - w_ = w_.permute(0, 2, 1) + w_ = w_.permute(0, 2, 1) h_ = torch.bmm(v, w_) h_ = h_.reshape(b, c, h, w) @@ -272,18 +270,18 @@ class Encoder(nn.Module): def forward(self, x): for block in self.blocks: x = block(x) - + return x class Generator(nn.Module): def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): super().__init__() - self.nf = nf - self.ch_mult = ch_mult + self.nf = nf + self.ch_mult = ch_mult self.num_resolutions = len(self.ch_mult) self.num_res_blocks = res_blocks - self.resolution = img_size + self.resolution = img_size self.attn_resolutions = attn_resolutions self.in_channels = emb_dim self.out_channels = 3 @@ -317,29 +315,29 @@ class Generator(nn.Module): blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) self.blocks = nn.ModuleList(blocks) - + def forward(self, x): for block in self.blocks: x = block(x) - + return x - + @ARCH_REGISTRY.register() class VQAutoEncoder(nn.Module): - def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, + def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256, beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): super().__init__() logger = get_root_logger() - self.in_channels = 3 - self.nf = nf - self.n_blocks = res_blocks + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks self.codebook_size = codebook_size self.embed_dim = emb_dim self.ch_mult = ch_mult self.resolution = img_size - self.attn_resolutions = attn_resolutions + self.attn_resolutions = attn_resolutions or [16] self.quantizer_type = quantizer self.encoder = Encoder( self.in_channels, @@ -365,11 +363,11 @@ class VQAutoEncoder(nn.Module): self.kl_weight ) self.generator = Generator( - self.nf, + self.nf, self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, + self.ch_mult, + self.n_blocks, + self.resolution, self.attn_resolutions ) @@ -434,4 +432,4 @@ class VQGANDiscriminator(nn.Module): raise ValueError('Wrong params!') def forward(self, x): - return self.main(x)
\ No newline at end of file + return self.main(x) |