diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-10-19 16:21:16 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-19 16:21:16 +0000 |
commit | c6e9fed5003631c87d548e74d6e359678959a453 (patch) | |
tree | 3404d9e47b58f4800b2e1a87427b48c822f4cb0e | |
parent | c664b231a836891d22081c5643c46aace180e427 (diff) | |
download | stable-diffusion-webui-gfx803-c6e9fed5003631c87d548e74d6e359678959a453.tar.gz stable-diffusion-webui-gfx803-c6e9fed5003631c87d548e74d6e359678959a453.tar.bz2 stable-diffusion-webui-gfx803-c6e9fed5003631c87d548e74d6e359678959a453.zip |
fix for #3086 failing to load any previous hypernet
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 60 |
1 files changed, 28 insertions, 32 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d519cd9..74300122 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -24,11 +24,10 @@ class HypernetworkModule(torch.nn.Module): def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
super().__init__()
- if layer_structure is not None:
- assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
- assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
- else:
- layer_structure = parse_layer_structure(dim, state_dict)
+
+ assert layer_structure is not None, "layer_structure mut not be None"
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
@@ -39,23 +38,30 @@ class HypernetworkModule(torch.nn.Module): self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
- try:
- self.load_state_dict(state_dict)
- except RuntimeError:
- self.try_load_previous(state_dict)
+ self.fix_old_state_dict(state_dict)
+ self.load_state_dict(state_dict)
else:
for layer in self.linear:
- layer.weight.data.normal_(mean = 0.0, std = 0.01)
+ layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
self.to(devices.device)
- def try_load_previous(self, state_dict):
- states = self.state_dict()
- states['linear.0.bias'].copy_(state_dict['linear1.bias'])
- states['linear.0.weight'].copy_(state_dict['linear1.weight'])
- states['linear.1.bias'].copy_(state_dict['linear2.bias'])
- states['linear.1.weight'].copy_(state_dict['linear2.weight'])
+ def fix_old_state_dict(self, state_dict):
+ changes = {
+ 'linear1.bias': 'linear.0.bias',
+ 'linear1.weight': 'linear.0.weight',
+ 'linear2.bias': 'linear.1.bias',
+ 'linear2.weight': 'linear.1.weight',
+ }
+
+ for fr, to in changes.items():
+ x = state_dict.get(fr, None)
+ if x is None:
+ continue
+
+ del state_dict[fr]
+ state_dict[to] = x
def forward(self, x):
return x + self.linear(x) * self.multiplier
@@ -71,18 +77,6 @@ def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-def parse_layer_structure(dim, state_dict):
- i = 0
- layer_structure = [1]
-
- while (key := "linear.{}.weight".format(i)) in state_dict:
- weight = state_dict[key]
- layer_structure.append(len(weight) // dim)
- i += 1
-
- return layer_structure
-
-
class Hypernetwork:
filename = None
name = None
@@ -135,17 +129,18 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu')
+ self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
+
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
- HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
- HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
+ HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
)
self.name = state_dict.get('name', self.name)
self.step = state_dict.get('step', 0)
- self.layer_structure = state_dict.get('layer_structure', None)
- self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
@@ -244,6 +239,7 @@ def stack_conds(conds): return torch.stack(conds)
+
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected'
|