aboutsummaryrefslogtreecommitdiffstats
path: root/modules/ui.py
diff options
context:
space:
mode:
authoraria1th <35677394+aria1th@users.noreply.github.com>2023-01-10 05:56:57 +0000
committeraria1th <35677394+aria1th@users.noreply.github.com>2023-01-10 05:56:57 +0000
commita4a5475cfa3c68af6cb046081002a72f862ce4be (patch)
tree54deee50926938b7be198b608bcfbdae7e7cb370 /modules/ui.py
parentbd4587d2f5b70ed951d2c17f25a4622fc1cb31c2 (diff)
downloadstable-diffusion-webui-gfx803-a4a5475cfa3c68af6cb046081002a72f862ce4be.tar.gz
stable-diffusion-webui-gfx803-a4a5475cfa3c68af6cb046081002a72f862ce4be.tar.bz2
stable-diffusion-webui-gfx803-a4a5475cfa3c68af6cb046081002a72f862ce4be.zip
Variable dropout rate
Implements variable dropout rate from #4549 Fixes hypernetwork multiplier being able to modified during training, also fixes user-errors by setting multiplier value to lower values for training. Changes function name to match torch.nn.module standard Fixes RNG reset issue when generating previews by restoring RNG state
Diffstat (limited to 'modules/ui.py')
-rw-r--r--modules/ui.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/modules/ui.py b/modules/ui.py
index b6079aec..9b9081b5 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1268,6 +1268,7 @@ def create_ui():
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
+ new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
with gr.Row():
@@ -1414,7 +1415,8 @@ def create_ui():
new_hypernetwork_activation_func,
new_hypernetwork_initialization_option,
new_hypernetwork_add_layer_norm,
- new_hypernetwork_use_dropout
+ new_hypernetwork_use_dropout,
+ new_hypernetwork_dropout_structure
],
outputs=[
train_hypernetwork_name,