diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-10-11 15:04:47 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-11 15:04:47 +0000 |
commit | d682444ecc99319fbd2b142a12727501e2884ba7 (patch) | |
tree | 0b4cb662da210c35c513e95b2786d2eb4082d6df /modules/hypernetworks | |
parent | 5ba23cb41f28f5856a7f64cb0d95e1e94dce90af (diff) | |
download | stable-diffusion-webui-gfx803-d682444ecc99319fbd2b142a12727501e2884ba7.tar.gz stable-diffusion-webui-gfx803-d682444ecc99319fbd2b142a12727501e2884ba7.tar.bz2 stable-diffusion-webui-gfx803-d682444ecc99319fbd2b142a12727501e2884ba7.zip |
add option to select hypernetwork modules when creating
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 4 | ||||
-rw-r--r-- | modules/hypernetworks/ui.py | 4 |
2 files changed, 4 insertions, 4 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index aa701bda..b081f14e 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -42,7 +42,7 @@ class Hypernetwork: filename = None
name = None
- def __init__(self, name=None):
+ def __init__(self, name=None, enable_sizes=None):
self.filename = None
self.name = name
self.layers = {}
@@ -50,7 +50,7 @@ class Hypernetwork: self.sd_checkpoint = None
self.sd_checkpoint_name = None
- for size in [320, 640, 768, 1280]:
+ for size in enable_sizes or [320, 640, 768, 1280]:
self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
def weights(self):
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index e7540f41..cdddcce1 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,11 +9,11 @@ from modules import sd_hijack, shared from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name):
+def create_hypernetwork(name, enable_sizes):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name)
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
hypernet.save(fn)
shared.reload_hypernetworks()
|