From 12c4d5c6b5bf9dd50d0601c36af4f99b65316d58 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 23:22:22 +0300 Subject: hypernetwork training mk1 --- modules/hypernetwork/ui.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 modules/hypernetwork/ui.py (limited to 'modules/hypernetwork/ui.py') diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py new file mode 100644 index 00000000..525f978c --- /dev/null +++ b/modules/hypernetwork/ui.py @@ -0,0 +1,43 @@ +import html +import os + +import gradio as gr + +import modules.textual_inversion.textual_inversion +import modules.textual_inversion.preprocess +from modules import sd_hijack, shared + + +def create_hypernetwork(name): + fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernetwork.save(fn) + + shared.reload_hypernetworks() + shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None) + + return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" + + +def train_hypernetwork(*args): + + initial_hypernetwork = shared.hypernetwork + + try: + sd_hijack.undo_optimizations() + + hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. +Hypernetwork saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + shared.hypernetwork = initial_hypernetwork + sd_hijack.apply_optimizations() + -- cgit v1.2.3 From 530103b586109c11fd068eb70ef09503ec6a4caf Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 14:53:02 +0300 Subject: fixes related to merge --- modules/hypernetwork/ui.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules/hypernetwork/ui.py') diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py index 525f978c..f6d1d0a3 100644 --- a/modules/hypernetwork/ui.py +++ b/modules/hypernetwork/ui.py @@ -6,24 +6,24 @@ import gradio as gr import modules.textual_inversion.textual_inversion import modules.textual_inversion.preprocess from modules import sd_hijack, shared +from modules.hypernetwork import hypernetwork def create_hypernetwork(name): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" - hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) - hypernetwork.save(fn) + hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernet.save(fn) shared.reload_hypernetworks() - shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None) return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" def train_hypernetwork(*args): - initial_hypernetwork = shared.hypernetwork + initial_hypernetwork = shared.loaded_hypernetwork try: sd_hijack.undo_optimizations() @@ -38,6 +38,6 @@ Hypernetwork saved to {html.escape(filename)} except Exception: raise finally: - shared.hypernetwork = initial_hypernetwork + shared.loaded_hypernetwork = initial_hypernetwork sd_hijack.apply_optimizations() -- cgit v1.2.3