diff options
author | Aarni Koskela <akx@iki.fi> | 2023-05-11 15:28:15 +0000 |
---|---|---|
committer | Aarni Koskela <akx@iki.fi> | 2023-05-11 17:29:11 +0000 |
commit | 49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c (patch) | |
tree | d79f004eae46bc1c49832f3c668a524107c30034 /modules/codeformer/vqgan_arch.py | |
parent | 431bc5a297ff7c17231b92b6c8f8152b2fab8553 (diff) | |
download | stable-diffusion-webui-gfx803-49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c.tar.gz stable-diffusion-webui-gfx803-49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c.tar.bz2 stable-diffusion-webui-gfx803-49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c.zip |
Autofix Ruff W (not W605) (mostly whitespace)
Diffstat (limited to 'modules/codeformer/vqgan_arch.py')
-rw-r--r-- | modules/codeformer/vqgan_arch.py | 38 |
1 files changed, 19 insertions, 19 deletions
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py index b24a0394..09ee6660 100644 --- a/modules/codeformer/vqgan_arch.py +++ b/modules/codeformer/vqgan_arch.py @@ -13,7 +13,7 @@ from basicsr.utils.registry import ARCH_REGISTRY def normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - + @torch.jit.script def swish(x): @@ -210,15 +210,15 @@ class AttnBlock(nn.Module): # compute attention b, c, h, w = q.shape q = q.reshape(b, c, h*w) - q = q.permute(0, 2, 1) + q = q.permute(0, 2, 1) k = k.reshape(b, c, h*w) - w_ = torch.bmm(q, k) + w_ = torch.bmm(q, k) w_ = w_ * (int(c)**(-0.5)) w_ = F.softmax(w_, dim=2) # attend to values v = v.reshape(b, c, h*w) - w_ = w_.permute(0, 2, 1) + w_ = w_.permute(0, 2, 1) h_ = torch.bmm(v, w_) h_ = h_.reshape(b, c, h, w) @@ -270,18 +270,18 @@ class Encoder(nn.Module): def forward(self, x): for block in self.blocks: x = block(x) - + return x class Generator(nn.Module): def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): super().__init__() - self.nf = nf - self.ch_mult = ch_mult + self.nf = nf + self.ch_mult = ch_mult self.num_resolutions = len(self.ch_mult) self.num_res_blocks = res_blocks - self.resolution = img_size + self.resolution = img_size self.attn_resolutions = attn_resolutions self.in_channels = emb_dim self.out_channels = 3 @@ -315,24 +315,24 @@ class Generator(nn.Module): blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) self.blocks = nn.ModuleList(blocks) - + def forward(self, x): for block in self.blocks: x = block(x) - + return x - + @ARCH_REGISTRY.register() class VQAutoEncoder(nn.Module): def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256, beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): super().__init__() logger = get_root_logger() - self.in_channels = 3 - self.nf = nf - self.n_blocks = res_blocks + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks self.codebook_size = codebook_size self.embed_dim = emb_dim self.ch_mult = ch_mult @@ -363,11 +363,11 @@ class VQAutoEncoder(nn.Module): self.kl_weight ) self.generator = Generator( - self.nf, + self.nf, self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, + self.ch_mult, + self.n_blocks, + self.resolution, self.attn_resolutions ) @@ -432,4 +432,4 @@ class VQGANDiscriminator(nn.Module): raise ValueError('Wrong params!') def forward(self, x): - return self.main(x)
\ No newline at end of file + return self.main(x) |