-rw-r--r--  modules/hypernetworks/hypernetwork.py  72
-rw-r--r--  modules/shared.py                        4
-rw-r--r--  webui.py                                 4
3 files changed, 66 insertions, 14 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 4905710e..082165f4 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -19,34 +19,82 @@ from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
+def parse_layer_structure(dim, state_dict):
+ i = 0
+ res = [1]
+ while (key := "linear.{}.weight".format(i)) in state_dict:
+ weight = state_dict[key]
+ res.append(len(weight) // dim)
+ i += 1
+ return res
+
+
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
+ layer_structure = None
+ add_layer_norm = False
def __init__(self, dim, state_dict=None):
super().__init__()
+ if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None:
+ layer_structure = (1, 2, 1)
+ else:
+ if self.layer_structure is not None:
+ assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+ assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
+ layer_structure = self.layer_structure
+ else:
+ layer_structure = parse_layer_structure(dim, state_dict)
+
+ linears = []
+ for i in range(len(layer_structure) - 1):
+ linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+ if self.add_layer_norm:
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
- self.linear1 = torch.nn.Linear(dim, dim * 2)
- self.linear2 = torch.nn.Linear(dim * 2, dim)
+ self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
- self.load_state_dict(state_dict, strict=True)
+ try:
+ self.load_state_dict(state_dict)
+ except RuntimeError:
+ self.try_load_previous(state_dict)
else:
-
- self.linear1.weight.data.normal_(mean=0.0, std=0.01)
- self.linear1.bias.data.zero_()
- self.linear2.weight.data.normal_(mean=0.0, std=0.01)
- self.linear2.bias.data.zero_()
+ for layer in self.linear:
+                layer.weight.data.normal_(mean=0.0, std=0.01)
+ layer.bias.data.zero_()
self.to(devices.device)
+ def try_load_previous(self, state_dict):
+ states = self.state_dict()
+ states['linear.0.bias'].copy_(state_dict['linear1.bias'])
+ states['linear.0.weight'].copy_(state_dict['linear1.weight'])
+ states['linear.1.bias'].copy_(state_dict['linear2.bias'])
+ states['linear.1.weight'].copy_(state_dict['linear2.weight'])
+
def forward(self, x):
- return x + (self.linear2(self.linear1(x))) * self.multiplier
+ return x + self.linear(x) * self.multiplier
+
+ def trainables(self):
+ res = []
+ for layer in self.linear:
+ res += [layer.weight, layer.bias]
+ return res
def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
+def apply_layer_structure(value=None):
+ HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure
+
+
+def apply_layer_norm(value=None):
+ HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm
+
+
class Hypernetwork:
filename = None
name = None
@@ -68,7 +116,7 @@ class Hypernetwork:
for k, layers in self.layers.items():
for layer in layers:
layer.train()
- res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
+ res += layer.trainables()
return res
@@ -226,7 +274,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
-
+    assert ds.length > 1, "Dataset should contain more than 1 image"
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
@@ -261,7 +309,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
-# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss = shared.sd_model(x, c)[0]
del x
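
For orientation, here is a minimal sketch of what the configurable layer structure added above does: the tuple is expanded into a stack of Linear layers scaled by the context dimension, and parse_layer_structure recovers that tuple from a saved state dict. The dim value of 768 and the TinyHyperModule name are illustrative only, not part of the commit; the sketch omits layer norm, since interleaved LayerNorm modules shift the indices that parse_layer_structure walks.

    import torch

    class TinyHyperModule(torch.nn.Module):
        # Illustrative stand-in for HypernetworkModule, built the same way
        # as the new __init__ above (no layer norm in this sketch).
        def __init__(self, dim, layer_structure):
            super().__init__()
            linears = []
            for i in range(len(layer_structure) - 1):
                linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
            self.linear = torch.nn.Sequential(*linears)

    def parse_layer_structure(dim, state_dict):
        # Same walk as the function added in the diff: each linear.N.weight
        # has out_features rows, so len(weight) // dim gives the width multiplier.
        i = 0
        res = [1]
        while (key := "linear.{}.weight".format(i)) in state_dict:
            res.append(len(state_dict[key]) // dim)
            i += 1
        return res

    m = TinyHyperModule(768, (1, 2, 4, 2, 1))          # 768 is an illustrative dim
    print(parse_layer_structure(768, m.state_dict()))  # [1, 2, 4, 2, 1]
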
diff --git a/modules/shared.py b/modules/shared.py
index f7d66870..0540cae9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -137,7 +137,7 @@ class State:
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
-
+
def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
@@ -260,6 +260,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+    "sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure. Default: (1, 2, 1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}),
+ "sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
diff --git a/webui.py b/webui.py
index 71724c3b..c7260c7a 100644
--- a/webui.py
+++ b/webui.py
@@ -85,7 +85,9 @@ def initialize():
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
-
+ shared.opts.onchange("sd_hypernetwork_layer_structure", modules.hypernetworks.hypernetwork.apply_layer_structure)
+ shared.opts.onchange("sd_hypernetwork_add_layer_norm", modules.hypernetworks.hypernetwork.apply_layer_norm)
+
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
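
A rough sketch of the mechanism these two new onchange hooks rely on: apply_layer_structure and apply_layer_norm simply overwrite class-level defaults on HypernetworkModule, so the settings take effect for modules constructed afterwards (for example, the next time a hypernetwork is created or loaded). The class and function below are simplified stand-ins, not the webui code.

    class Module:
        # class-level defaults, mirroring HypernetworkModule.layer_structure
        # and HypernetworkModule.add_layer_norm
        layer_structure = None
        add_layer_norm = False

        def __init__(self):
            # simplified fallback to the old fixed shape when nothing is configured
            self.structure = self.layer_structure or (1, 2, 1)

    def apply_layer_structure(value=None):
        Module.layer_structure = value

    apply_layer_structure((1, 2, 2, 1))   # what the onchange callback does
    print(Module().structure)             # (1, 2, 2, 1)
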