aboutsummaryrefslogtreecommitdiffstats
path: root/modules/hypernetworks/ui.py
diff options
context:
space:
mode:
authoraria1th <35677394+aria1th@users.noreply.github.com>2023-01-10 05:56:57 +0000
committeraria1th <35677394+aria1th@users.noreply.github.com>2023-01-10 05:56:57 +0000
commita4a5475cfa3c68af6cb046081002a72f862ce4be (patch)
tree54deee50926938b7be198b608bcfbdae7e7cb370 /modules/hypernetworks/ui.py
parentbd4587d2f5b70ed951d2c17f25a4622fc1cb31c2 (diff)
downloadstable-diffusion-webui-gfx803-a4a5475cfa3c68af6cb046081002a72f862ce4be.tar.gz
stable-diffusion-webui-gfx803-a4a5475cfa3c68af6cb046081002a72f862ce4be.tar.bz2
stable-diffusion-webui-gfx803-a4a5475cfa3c68af6cb046081002a72f862ce4be.zip
Variable dropout rate
Implements variable dropout rate from #4549 Fixes hypernetwork multiplier being able to modified during training, also fixes user-errors by setting multiplier value to lower values for training. Changes function name to match torch.nn.module standard Fixes RNG reset issue when generating previews by restoring RNG state
Diffstat (limited to 'modules/hypernetworks/ui.py')
-rw-r--r--modules/hypernetworks/ui.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index e7f9e593..81e3f519 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -9,8 +9,8 @@ from modules import devices, sd_hijack, shared
not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
- filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
+ filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""