author    | AUTOMATIC <16777216c@gmail.com> | 2023-05-10 08:19:16 +0000
committer | AUTOMATIC <16777216c@gmail.com> | 2023-05-10 08:19:16 +0000
commit    | 550256db1ce18778a9d56ff343d844c61b9f9b83
tree      | a17e8fd9cb475381c361844970ba2d9111938b6d /modules/codeformer
parent    | 028d3f6425d85f122027c127fba8bcbf4f66ee75
ruff manual fixes
Diffstat (limited to 'modules/codeformer')
-rw-r--r-- | modules/codeformer/codeformer_arch.py | 7
-rw-r--r-- | modules/codeformer/vqgan_arch.py      | 4
2 files changed, 7 insertions, 4 deletions
diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py
index 00c407de..ff1c0b4b 100644
--- a/modules/codeformer/codeformer_arch.py
+++ b/modules/codeformer/codeformer_arch.py
@@ -161,10 +161,13 @@ class Fuse_sft_block(nn.Module):
 
 class CodeFormer(VQAutoEncoder):
     def __init__(self, dim_embd=512, n_head=8, n_layers=9, codebook_size=1024, latent_size=256,
-                connect_list=['32', '64', '128', '256'],
-                fix_modules=['quantize','generator']):
+                connect_list=None,
+                fix_modules=None):
         super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
 
+        connect_list = connect_list or ['32', '64', '128', '256']
+        fix_modules = fix_modules or ['quantize', 'generator']
+
         if fix_modules is not None:
             for module in fix_modules:
                 for param in getattr(self, module).parameters():
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
index 820e6b12..b24a0394 100644
--- a/modules/codeformer/vqgan_arch.py
+++ b/modules/codeformer/vqgan_arch.py
@@ -326,7 +326,7 @@ class Generator(nn.Module):
 
 @ARCH_REGISTRY.register()
 class VQAutoEncoder(nn.Module):
-    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
                  beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
         super().__init__()
         logger = get_root_logger()
@@ -337,7 +337,7 @@ class VQAutoEncoder(nn.Module):
         self.embed_dim = emb_dim
         self.ch_mult = ch_mult
         self.resolution = img_size
-        self.attn_resolutions = attn_resolutions
+        self.attn_resolutions = attn_resolutions or [16]
         self.quantizer_type = quantizer
         self.encoder = Encoder(
             self.in_channels,
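
For context, both changes above replace mutable default arguments (list literals in a def signature) with None plus an `or` fallback inside the body, the mutable-default-argument pitfall that ruff reports (flake8-bugbear rule B006) and which apparently had to be fixed by hand here. A minimal sketch of the problem and the fix, not taken from the commit (the add_item_* names are made up for illustration):

# Default list objects are created once, when the function is defined,
# so every call that omits the argument shares the same list.
def add_item_buggy(item, items=[]):
    items.append(item)
    return items

# Pattern used in this commit: default to None, build a fresh list per call.
def add_item_fixed(item, items=None):
    items = items or []
    items.append(item)
    return items

print(add_item_buggy("a"), add_item_buggy("b"))   # ['a', 'b'] ['a', 'b']  <- same shared list
print(add_item_fixed("a"), add_item_fixed("b"))   # ['a'] ['b']

One nuance of the `or` fallback: an explicitly passed empty list is also replaced by the default, which is harmless for constructor options like connect_list, fix_modules, and attn_resolutions here; an `if items is None` check would preserve a caller's empty list instead.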