diff options
author | captin411 <captindave@gmail.com> | 2022-10-25 13:22:27 -0700 |
---|---|---|
committer | captin411 <captindave@gmail.com> | 2022-10-25 13:22:27 -0700 |
commit | 6629446a2f9bb3ade1c271854aae1530ba1a8cc3 (patch) | |
tree | ad7cfd2b3f0208c24da64c7f08e0550e783228ec /modules/hypernetworks/ui.py | |
parent | 3e6c2420c1177e9e79f2b566a5a7795b7416e34a (diff) | |
parent | 3e15f8e0f5cc87507f77546d92435670644dbd18 (diff) | |
download | stable-diffusion-webui-gfx803-6629446a2f9bb3ade1c271854aae1530ba1a8cc3.tar.gz |
Merge branch 'master' into focal-point-cropping
Diffstat (limited to 'modules/hypernetworks/ui.py')
-rw-r--r-- | modules/hypernetworks/ui.py | 28 |
1 files changed, 21 insertions, 7 deletions
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index dfa599af..2b472d87 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -1,19 +1,33 @@ import html
import os
+import re
import gradio as gr
-
-import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared, devices
+import modules.textual_inversion.textual_inversion
+from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name, enable_sizes):
- fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ if type(layer_structure) == str:
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
+
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
+ name=name,
+ enable_sizes=[int(x) for x in enable_sizes],
+ layer_structure=layer_structure,
+ activation_func=activation_func,
+ add_layer_norm=add_layer_norm,
+ use_dropout=use_dropout,
+ )
hypernet.save(fn)
shared.reload_hypernetworks()
|