From 873efeed49bb5197a42da18272115b326c5d68f3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 15:51:22 +0300 Subject: rename hypernetwork dir to hypernetworks to prevent clash with an old filename that people who use zip instead of git clone will have --- modules/hypernetworks/ui.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 modules/hypernetworks/ui.py (limited to 'modules/hypernetworks/ui.py') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py new file mode 100644 index 00000000..811bc31e --- /dev/null +++ b/modules/hypernetworks/ui.py @@ -0,0 +1,43 @@ +import html +import os + +import gradio as gr + +import modules.textual_inversion.textual_inversion +import modules.textual_inversion.preprocess +from modules import sd_hijack, shared +from modules.hypernetworks import hypernetwork + + +def create_hypernetwork(name): + fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernet.save(fn) + + shared.reload_hypernetworks() + + return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" + + +def train_hypernetwork(*args): + + initial_hypernetwork = shared.loaded_hypernetwork + + try: + sd_hijack.undo_optimizations() + + hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. +Hypernetwork saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + shared.loaded_hypernetwork = initial_hypernetwork + sd_hijack.apply_optimizations() + -- cgit v1.2.3 From b0583be0884cd17dafb408fd79b52b2a0a972563 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 15:54:34 +0300 Subject: more renames --- modules/hypernetworks/ui.py | 4 ++-- modules/ui.py | 4 ++-- webui.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) (limited to 'modules/hypernetworks/ui.py') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 811bc31e..e7540f41 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -13,7 +13,7 @@ def create_hypernetwork(name): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" - hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name) hypernet.save(fn) shared.reload_hypernetworks() @@ -28,7 +28,7 @@ def train_hypernetwork(*args): try: sd_hijack.undo_optimizations() - hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) + hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args) res = f""" Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. diff --git a/modules/ui.py b/modules/ui.py index 42e5d866..ee333c3b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1111,7 +1111,7 @@ def create_ui(wrap_gradio_gpu_call): ) create_hypernetwork.click( - fn=modules.hypernetwork.ui.create_hypernetwork, + fn=modules.hypernetworks.ui.create_hypernetwork, inputs=[ new_hypernetwork_name, ], @@ -1164,7 +1164,7 @@ def create_ui(wrap_gradio_gpu_call): ) train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetwork.ui.train_hypernetwork, extra_outputs=[gr.update()]), + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ train_hypernetwork_name, diff --git a/webui.py b/webui.py index faa38a0d..338f58e1 100644 --- a/webui.py +++ b/webui.py @@ -83,7 +83,7 @@ modules.scripts.load_scripts(os.path.join(script_path, "scripts")) shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) -shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) +shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) def webui(): -- cgit v1.2.3