diff options
author | aria1th <35677394+aria1th@users.noreply.github.com> | 2023-01-10 05:56:57 +0000 |
---|---|---|
committer | aria1th <35677394+aria1th@users.noreply.github.com> | 2023-01-10 05:56:57 +0000 |
commit | a4a5475cfa3c68af6cb046081002a72f862ce4be (patch) | |
tree | 54deee50926938b7be198b608bcfbdae7e7cb370 /modules | |
parent | bd4587d2f5b70ed951d2c17f25a4622fc1cb31c2 (diff) | |
download | stable-diffusion-webui-gfx803-a4a5475cfa3c68af6cb046081002a72f862ce4be.tar.gz stable-diffusion-webui-gfx803-a4a5475cfa3c68af6cb046081002a72f862ce4be.tar.bz2 stable-diffusion-webui-gfx803-a4a5475cfa3c68af6cb046081002a72f862ce4be.zip |
Variable dropout rate
Implements variable dropout rate from #4549
Fixes hypernetwork multiplier being able to modified during training, also fixes user-errors by setting multiplier value to lower values for training.
Changes function name to match torch.nn.module standard
Fixes RNG reset issue when generating previews by restoring RNG state
Diffstat (limited to 'modules')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 101 | ||||
-rw-r--r-- | modules/hypernetworks/ui.py | 4 | ||||
-rw-r--r-- | modules/ui.py | 4 |
3 files changed, 81 insertions, 28 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index ea3f1db9..300d3975 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -39,7 +39,7 @@ class HypernetworkModule(torch.nn.Module): activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
- add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
+ add_layer_norm=False, activate_output=False, dropout_structure=None):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@@ -64,9 +64,12 @@ class HypernetworkModule(torch.nn.Module): if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
- # Add dropout except last layer
- if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
- linears.append(torch.nn.Dropout(p=0.3))
+ # Everything should be now parsed into dropout structure, and applied here.
+ # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
+ if dropout_structure is not None and dropout_structure[i+1] > 0:
+ assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
+ linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
+ # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].
self.linear = torch.nn.Sequential(*linears)
@@ -113,7 +116,7 @@ class HypernetworkModule(torch.nn.Module): state_dict[to] = x
def forward(self, x):
- return x + self.linear(x) * self.multiplier
+ return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
def trainables(self):
layer_structure = []
@@ -126,6 +129,21 @@ class HypernetworkModule(torch.nn.Module): def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
+#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
+def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
+ if layer_structure is None:
+ layer_structure = [1, 2, 1]
+ if not use_dropout:
+ return [0] * len(layer_structure)
+ dropout_values = [0]
+ dropout_values.extend([0.3] * (len(layer_structure) - 3))
+ if last_layer_dropout:
+ dropout_values.append(0.3)
+ else:
+ dropout_values.append(0)
+ dropout_values.append(0)
+ return dropout_values
+
class Hypernetwork:
filename = None
@@ -144,18 +162,22 @@ class Hypernetwork: self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
self.activate_output = activate_output
- self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
+ self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
+ self.dropout_structure = kwargs.get('dropout_structure', None)
+ if self.dropout_structure is None:
+ self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
self.optimizer_name = None
self.optimizer_state_dict = None
+ self.optional_info = None
for size in enable_sizes or []:
self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+ self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+ self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
)
- self.eval_mode()
+ self.eval()
def weights(self):
res = []
@@ -164,14 +186,14 @@ class Hypernetwork: res += layer.parameters()
return res
- def train_mode(self):
+ def train(self, mode=True):
for k, layers in self.layers.items():
for layer in layers:
- layer.train()
+ layer.train(mode=mode)
for param in layer.parameters():
- param.requires_grad = True
+ param.requires_grad = mode
- def eval_mode(self):
+ def eval(self):
for k, layers in self.layers.items():
for layer in layers:
layer.eval()
@@ -191,11 +213,13 @@ class Hypernetwork: state_dict['activation_func'] = self.activation_func
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['weight_initialization'] = self.weight_init
- state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
state_dict['activate_output'] = self.activate_output
- state_dict['last_layer_dropout'] = self.last_layer_dropout
+ state_dict['use_dropout'] = self.use_dropout
+ state_dict['dropout_structure'] = self.dropout_structure
+ state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
+ state_dict['optional_info'] = self.optional_info if self.optional_info else None
if self.optimizer_name is not None:
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
@@ -215,43 +239,56 @@ class Hypernetwork: self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
print(self.layer_structure)
+ optional_info = state_dict.get('optional_info', None)
+ if optional_info is not None:
+ print(f"INFO:\n {optional_info}\n")
+ self.optional_info = optional_info
self.activation_func = state_dict.get('activation_func', None)
print(f"Activation function is {self.activation_func}")
self.weight_init = state_dict.get('weight_initialization', 'Normal')
print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False)
print(f"Layer norm is set to {self.add_layer_norm}")
- self.use_dropout = state_dict.get('use_dropout', False)
+ self.dropout_structure = state_dict.get('dropout_structure', None)
+ self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True)
print(f"Activate last layer is set to {self.activate_output}")
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
+ # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
+ if self.dropout_structure is None:
+ print("Using previous dropout structure")
+ self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
+ print(f"Dropout structure is set to {self.dropout_structure}")
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
- self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
- print(f"Optimizer name is {self.optimizer_name}")
+
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
else:
self.optimizer_state_dict = None
if self.optimizer_state_dict:
+ self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
print("Loaded existing optimizer from checkpoint")
+ print(f"Optimizer name is {self.optimizer_name}")
else:
+ self.optimizer_name = "AdamW"
print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+ self.add_layer_norm, self.activate_output, self.dropout_structure),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
- self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+ self.add_layer_norm, self.activate_output, self.dropout_structure),
)
self.name = state_dict.get('name', self.name)
self.step = state_dict.get('step', 0)
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
+ self.eval()
def list_hypernetworks(path):
@@ -379,9 +416,10 @@ def report_statistics(loss_info:dict): print(e)
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+ assert name, "Name cannot be empty!"
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
if not overwrite_old:
@@ -390,6 +428,11 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, if type(layer_structure) == str:
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
+ if use_dropout and dropout_structure and type(dropout_structure) == str:
+ dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
+ else:
+ dropout_structure = [0] * len(layer_structure)
+
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
name=name,
enable_sizes=[int(x) for x in enable_sizes],
@@ -398,6 +441,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, weight_init=weight_init,
add_layer_norm=add_layer_norm,
use_dropout=use_dropout,
+ dropout_structure=dropout_structure
)
hypernet.save(fn)
@@ -480,7 +524,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, shared.sd_model.first_stage_model.to(devices.cpu)
weights = hypernetwork.weights()
- hypernetwork.train_mode()
+ hypernetwork.train()
# Here we use optimizer from saved HN, or we can specify as UI option.
if hypernetwork.optimizer_name in optimizer_dict:
@@ -594,7 +638,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
- hypernetwork.eval_mode()
+ hypernetwork.eval()
+ rng_state = torch.get_rng_state()
+ cuda_rng_state = None
+ if torch.cuda.is_available():
+ cuda_rng_state = torch.cuda.get_rng_state_all()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
@@ -627,7 +675,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
- hypernetwork.train_mode()
+ torch.set_rng_state(rng_state)
+ if torch.cuda.is_available():
+ torch.cuda.set_rng_state_all(cuda_rng_state)
+ hypernetwork.train()
if image is not None:
shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
@@ -649,7 +700,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> finally:
pbar.leave = False
pbar.close()
- hypernetwork.eval_mode()
+ hypernetwork.eval()
#report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index e7f9e593..81e3f519 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,8 +9,8 @@ from modules import devices, sd_hijack, shared not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
- filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
+ filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
diff --git a/modules/ui.py b/modules/ui.py index b6079aec..9b9081b5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1268,6 +1268,7 @@ def create_ui(): new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
+ new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
with gr.Row():
@@ -1414,7 +1415,8 @@ def create_ui(): new_hypernetwork_activation_func,
new_hypernetwork_initialization_option,
new_hypernetwork_add_layer_norm,
- new_hypernetwork_use_dropout
+ new_hypernetwork_use_dropout,
+ new_hypernetwork_dropout_structure
],
outputs=[
train_hypernetwork_name,
|