diff options
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 24 |
1 files changed, 18 insertions, 6 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 4541af18..9388959f 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -154,16 +154,28 @@ class Hypernetwork: HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
+ self.eval_mode()
def weights(self):
res = []
+ for k, layers in self.layers.items():
+ for layer in layers:
+ res += layer.parameters()
+ return res
+ def train_mode(self):
for k, layers in self.layers.items():
for layer in layers:
layer.train()
- res += layer.trainables()
+ for param in layer.parameters():
+ param.requires_grad = True
- return res
+ def eval_mode(self):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
def save(self, filename):
state_dict = {}
@@ -426,8 +438,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, shared.sd_model.first_stage_model.to(devices.cpu)
weights = hypernetwork.weights()
- for weight in weights:
- weight.requires_grad = True
+ hypernetwork.train_mode()
# Here we use optimizer from saved HN, or we can specify as UI option.
if hypernetwork.optimizer_name in optimizer_dict:
@@ -538,7 +549,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
-
+ hypernetwork.eval_mode()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
@@ -571,7 +582,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
-
+ hypernetwork.train_mode()
if image is not None:
shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
@@ -593,6 +604,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> finally:
pbar.leave = False
pbar.close()
+ hypernetwork.eval_mode()
#report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|