diff options
author | yfszzx <yfszzx@gmail.com> | 2022-10-12 13:24:40 +0000 |
---|---|---|
committer | yfszzx <yfszzx@gmail.com> | 2022-10-12 13:24:40 +0000 |
commit | c87c3b9c1169f8a9b632d6d8c8675d98956c387c (patch) | |
tree | eeeb4ff5e05af265686ce3a7916a0df2f30113e4 /modules/hypernetworks/ui.py | |
parent | 511ca57e37483aac0cf260c89838ad2948509101 (diff) | |
parent | 429442f4a6aab7301efb89d27bef524fe827e81a (diff) | |
download | stable-diffusion-webui-gfx803-c87c3b9c1169f8a9b632d6d8c8675d98956c387c.tar.gz stable-diffusion-webui-gfx803-c87c3b9c1169f8a9b632d6d8c8675d98956c387c.tar.bz2 stable-diffusion-webui-gfx803-c87c3b9c1169f8a9b632d6d8c8675d98956c387c.zip |
test
Diffstat (limited to 'modules/hypernetworks/ui.py')
-rw-r--r-- | modules/hypernetworks/ui.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index e7540f41..dfa599af 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -5,15 +5,15 @@ import gradio as gr import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared
+from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name):
+def create_hypernetwork(name, enable_sizes):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name)
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
hypernet.save(fn)
shared.reload_hypernetworks()
@@ -25,6 +25,8 @@ def train_hypernetwork(*args): initial_hypernetwork = shared.loaded_hypernetwork
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
+
try:
sd_hijack.undo_optimizations()
@@ -39,5 +41,7 @@ Hypernetwork saved to {html.escape(filename)} raise
finally:
shared.loaded_hypernetwork = initial_hypernetwork
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()
|