diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-05-10 18:21:32 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-05-10 18:21:32 +0000 |
commit | 3ec7b705c78b7aca9569c92a419837352c7a4ec6 (patch) | |
tree | 98248bc21aa4ad9715205f0a65a654532c6cfcc0 /extensions-builtin/SwinIR | |
parent | d25219b7e889cf34bccae9cb88497708796efda2 (diff) | |
download | stable-diffusion-webui-gfx803-3ec7b705c78b7aca9569c92a419837352c7a4ec6.tar.gz stable-diffusion-webui-gfx803-3ec7b705c78b7aca9569c92a419837352c7a4ec6.tar.bz2 stable-diffusion-webui-gfx803-3ec7b705c78b7aca9569c92a419837352c7a4ec6.zip |
suggestions and fixes from the PR
Diffstat (limited to 'extensions-builtin/SwinIR')
-rw-r--r-- | extensions-builtin/SwinIR/swinir_model_arch.py | 6 | ||||
-rw-r--r-- | extensions-builtin/SwinIR/swinir_model_arch_v2.py | 11 |
2 files changed, 3 insertions, 14 deletions
diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py index de195d9b..73e37cfa 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ b/extensions-builtin/SwinIR/swinir_model_arch.py @@ -644,17 +644,13 @@ class SwinIR(nn.Module): """ def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=None, num_heads=None, + embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', **kwargs): super(SwinIR, self).__init__() - - depths = depths or [6, 6, 6, 6] - num_heads = num_heads or [6, 6, 6, 6] - num_in_ch = in_chans num_out_ch = in_chans num_feat = 64 diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py index 15777af9..3ca9be78 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py @@ -74,12 +74,9 @@ class WindowAttention(nn.Module): """
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=None):
+ pretrained_window_size=(0, 0)):
super().__init__()
-
- pretrained_window_size = pretrained_window_size or [0, 0]
-
self.dim = dim
self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size
@@ -701,17 +698,13 @@ class Swin2SR(nn.Module): """
def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=None, num_heads=None,
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
super(Swin2SR, self).__init__()
-
- depths = depths or [6, 6, 6, 6]
- num_heads = num_heads or [6, 6, 6, 6]
-
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
|