author    AUTOMATIC <16777216c@gmail.com>  2022-10-21 06:47:43 +0000
committer AUTOMATIC <16777216c@gmail.com>  2022-10-21 06:47:43 +0000
commit    c23f666dba2b484d521d2dc4be91cf9e09312647 (patch)
tree      44d68d6a646387973bbe45e9568dc17efaae228a
parent    a26fc2834c86d9e90e2d336ba670017552f38d29 (diff)
a stricter check for activation type and a more reasonable check for layer type in hypernets
-rw-r--r--  modules/hypernetworks/hypernetwork.py | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7d617680..84e7e350 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -32,10 +32,16 @@ class HypernetworkModule(torch.nn.Module):
         linears = []
         for i in range(len(layer_structure) - 1):
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
             if activation_func == "relu":
                 linears.append(torch.nn.ReLU())
-            if activation_func == "leakyrelu":
+            elif activation_func == "leakyrelu":
                 linears.append(torch.nn.LeakyReLU())
+            elif activation_func == 'linear' or activation_func is None:
+                pass
+            else:
+                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
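
For reference, the new dispatch can be exercised on its own. The sketch below mirrors the logic added in this hunk; the helper name build_linears and the Sequential wrapper are illustrative, not part of the webui code.

import torch

# Minimal standalone sketch of the activation dispatch added above.
# build_linears is a hypothetical helper, not webui code.
def build_linears(dim, layer_structure, activation_func=None, add_layer_norm=False):
    linears = []
    for i in range(len(layer_structure) - 1):
        linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))

        if activation_func == "relu":
            linears.append(torch.nn.ReLU())
        elif activation_func == "leakyrelu":
            linears.append(torch.nn.LeakyReLU())
        elif activation_func == 'linear' or activation_func is None:
            pass  # explicitly no nonlinearity between the Linear layers
        else:
            # unknown names now fail loudly instead of silently building
            # a hypernetwork with no activation at all
            raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')

        if add_layer_norm:
            linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
    return torch.nn.Sequential(*linears)

# build_linears(320, [1, 2, 1], "relu") -> Linear(320, 640), ReLU, Linear(640, 320), ReLU
# build_linears(320, [1, 2, 1], "mish") -> RuntimeError
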
@@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
             self.load_state_dict(state_dict)
         else:
             for layer in self.linear:
-                if not "ReLU" in layer.__str__():
+                if type(layer) == torch.nn.Linear:
                     layer.weight.data.normal_(mean=0.0, std=0.01)
                     layer.bias.data.zero_()
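
Why the type check matters: the old test matched on the module's repr, so any module whose repr lacked "ReLU" was re-initialized, including LayerNorm, whose ones-initialized weight would have been overwritten with small random noise. A small sketch of the difference (illustrative, not webui code):

import torch

norm = torch.nn.LayerNorm(640)
print("ReLU" in str(norm))            # False -> the old check would re-init LayerNorm too
print(type(norm) == torch.nn.Linear)  # False -> the new check leaves it alone

linear = torch.nn.Linear(640, 640)
if type(linear) == torch.nn.Linear:   # the new, explicit test
    linear.weight.data.normal_(mean=0.0, std=0.01)
    linear.bias.data.zero_()

isinstance(layer, torch.nn.Linear) would be the more idiomatic spelling, though the exact type comparison used here also excludes Linear subclasses.
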
@@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
     def trainables(self):
         layer_structure = []
         for layer in self.linear:
-            if not "ReLU" in layer.__str__():
+            if type(layer) == torch.nn.Linear:
                 layer_structure += [layer.weight, layer.bias]
         return layer_structure
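
A hedged usage sketch of trainables(): the returned list is what a training loop would hand to an optimizer, so only the Linear weights and biases receive updates. The constructor keywords are inferred from the names in the diff, and the sketch assumes the webui environment is importable:

import torch
from modules.hypernetworks.hypernetwork import HypernetworkModule

module = HypernetworkModule(320, layer_structure=[1, 2, 1], activation_func="relu")
optimizer = torch.optim.AdamW(module.trainables(), lr=1e-5)  # activation and LayerNorm parameters are excluded
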