From 7912acef725832debef58c4c7bf8ec22fb446c0b Mon Sep 17 00:00:00 2001 From: discus0434 Date: Sat, 22 Oct 2022 13:00:44 +0000 Subject: small fix --- modules/hypernetworks/hypernetwork.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) (limited to 'modules/hypernetworks') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3132a56c..7d12e0ff 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -42,22 +42,20 @@ class HypernetworkModule(torch.nn.Module): # Add an activation func if activation_func == "linear" or activation_func is None: pass - # If ReLU, Skip adding it to the first layer to avoid dying ReLU - elif activation_func == "relu" and i < 1: - pass elif activation_func in self.activation_dict: linears.append(self.activation_dict[activation_func]()) else: raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') - # Add dropout - if use_dropout: - linears.append(torch.nn.Dropout(p=0.3)) - # Add layer normalization if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) + # Add dropout + if use_dropout: + p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2 + linears.append(torch.nn.Dropout(p=p)) + self.linear = torch.nn.Sequential(*linears) if state_dict is not None: -- cgit v1.2.3