diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-10-21 06:47:43 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-21 06:47:43 +0000 |
commit | c23f666dba2b484d521d2dc4be91cf9e09312647 (patch) | |
tree | 44d68d6a646387973bbe45e9568dc17efaae228a /modules | |
parent | a26fc2834c86d9e90e2d336ba670017552f38d29 (diff) | |
download | stable-diffusion-webui-gfx803-c23f666dba2b484d521d2dc4be91cf9e09312647.tar.gz stable-diffusion-webui-gfx803-c23f666dba2b484d521d2dc4be91cf9e09312647.tar.bz2 stable-diffusion-webui-gfx803-c23f666dba2b484d521d2dc4be91cf9e09312647.zip |
a more strict check for activation type and a more reasonable check for type of layer in hypernets
Diffstat (limited to 'modules')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d617680..84e7e350 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -32,10 +32,16 @@ class HypernetworkModule(torch.nn.Module): linears = []
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
if activation_func == "relu":
linears.append(torch.nn.ReLU())
- if activation_func == "leakyrelu":
+ elif activation_func == "leakyrelu":
linears.append(torch.nn.LeakyReLU())
+ elif activation_func == 'linear' or activation_func is None:
+ pass
+ else:
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module): self.load_state_dict(state_dict)
else:
for layer in self.linear:
- if not "ReLU" in layer.__str__():
+ if type(layer) == torch.nn.Linear:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
@@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module): def trainables(self):
layer_structure = []
for layer in self.linear:
- if not "ReLU" in layer.__str__():
+ if type(layer) == torch.nn.Linear:
layer_structure += [layer.weight, layer.bias]
return layer_structure
|