From 550256db1ce18778a9d56ff343d844c61b9f9b83 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:19:16 +0300 Subject: ruff manual fixes --- extensions-builtin/SwinIR/swinir_model_arch.py | 6 +++++- extensions-builtin/SwinIR/swinir_model_arch_v2.py | 11 +++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) (limited to 'extensions-builtin/SwinIR') diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py index 863f42db..75f7bedc 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ b/extensions-builtin/SwinIR/swinir_model_arch.py @@ -644,13 +644,17 @@ class SwinIR(nn.Module): """ def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + embed_dim=96, depths=None, num_heads=None, window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', **kwargs): super(SwinIR, self).__init__() + + depths = depths or [6, 6, 6, 6] + num_heads = num_heads or [6, 6, 6, 6] + num_in_ch = in_chans num_out_ch = in_chans num_feat = 64 diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py index 0e28ae6e..d4c0b0da 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py @@ -74,9 +74,12 @@ class WindowAttention(nn.Module): """ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., - pretrained_window_size=[0, 0]): + pretrained_window_size=None): super().__init__() + + pretrained_window_size = pretrained_window_size or [0, 0] + self.dim = dim self.window_size = window_size # Wh, Ww self.pretrained_window_size = pretrained_window_size @@ -698,13 +701,17 @@ class Swin2SR(nn.Module): """ def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + embed_dim=96, depths=None, num_heads=None, window_size=7, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', **kwargs): super(Swin2SR, self).__init__() + + depths = depths or [6, 6, 6, 6] + num_heads = num_heads or [6, 6, 6, 6] + num_in_ch = in_chans num_out_ch = in_chans num_feat = 64 -- cgit v1.2.3