From 12c4d5c6b5bf9dd50d0601c36af4f99b65316d58 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 7 Oct 2022 23:22:22 +0300
Subject: hypernetwork training mk1
---
modules/ui.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++-------
1 file changed, 51 insertions(+), 7 deletions(-)
(limited to 'modules/ui.py')
diff --git a/modules/ui.py b/modules/ui.py
index 4f18126f..051908c1 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -37,6 +37,7 @@ import modules.generation_parameters_copypaste
from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
+import modules.hypernetwork.ui
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init()
@@ -965,6 +966,18 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
create_embedding = gr.Button(value="Create", variant='primary')
+ with gr.Group():
+ gr.HTML(value="
Create a new hypernetwork
")
+
+ new_hypernetwork_name = gr.Textbox(label="Name")
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_hypernetwork = gr.Button(value="Create", variant='primary')
+
with gr.Group():
gr.HTML(value="Preprocess images
")
@@ -986,6 +999,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
gr.HTML(value="Train an embedding; must specify a directory with a set of 512x512 images
")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
@@ -993,15 +1007,12 @@ def create_ui(wrap_gradio_gpu_call):
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
+ preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
with gr.Row():
- with gr.Column(scale=2):
- gr.HTML(value="")
-
- with gr.Column():
- with gr.Row():
- interrupt_training = gr.Button(value="Interrupt")
- train_embedding = gr.Button(value="Train", variant='primary')
+ interrupt_training = gr.Button(value="Interrupt")
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
+ train_embedding = gr.Button(value="Train Embedding", variant='primary')
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
@@ -1027,6 +1038,18 @@ def create_ui(wrap_gradio_gpu_call):
]
)
+ create_hypernetwork.click(
+ fn=modules.hypernetwork.ui.create_hypernetwork,
+ inputs=[
+ new_hypernetwork_name,
+ ],
+ outputs=[
+ train_hypernetwork_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
run_preprocess.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
@@ -1062,12 +1085,33 @@ def create_ui(wrap_gradio_gpu_call):
]
)
+ train_hypernetwork.click(
+ fn=wrap_gradio_gpu_call(modules.hypernetwork.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ train_hypernetwork_name,
+ learn_rate,
+ dataset_directory,
+ log_directory,
+ steps,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ preview_image_prompt,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
interrupt_training.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
+
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
--
cgit v1.2.3
From f347ddfd808c56bb1bacdec0c4bedf826ff85cd8 Mon Sep 17 00:00:00 2001
From: RW21
Date: Mon, 10 Oct 2022 10:44:11 +0900
Subject: Remove max_batch_count from ui.py
---
modules/ui.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
(limited to 'modules/ui.py')
diff --git a/modules/ui.py b/modules/ui.py
index 8c06ad7c..8ba84911 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -524,7 +524,7 @@ def create_ui(wrap_gradio_gpu_call):
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
with gr.Row():
- batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
@@ -710,7 +710,7 @@ def create_ui(wrap_gradio_gpu_call):
tiling = gr.Checkbox(label='Tiling', value=False)
with gr.Row():
- batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
with gr.Group():
--
cgit v1.2.3
From 9d33baba587637815d818e5e641d8f8b74c4900d Mon Sep 17 00:00:00 2001
From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Mon, 10 Oct 2022 18:46:48 +0300
Subject: Always show previous mask and fix extras_send dest
---
modules/ui.py | 2 +-
style.css | 7 +++++++
2 files changed, 8 insertions(+), 1 deletion(-)
(limited to 'modules/ui.py')
diff --git a/modules/ui.py b/modules/ui.py
index 8ba84911..e8039d76 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -961,7 +961,7 @@ def create_ui(wrap_gradio_gpu_call):
extras_send_to_inpaint.click(
fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_img2img",
+ _js="extract_image_from_gallery_inpaint",
inputs=[result_images],
outputs=[init_img_with_mask],
)
diff --git a/style.css b/style.css
index 04bb9576..00a3d07f 100644
--- a/style.css
+++ b/style.css
@@ -467,3 +467,10 @@ input[type="range"]{
max-width: 32em;
padding: 0;
}
+
+canvas[key="mask"] {
+ z-index: 12 !important;
+ filter: invert();
+ mix-blend-mode: multiply;
+ pointer-events: none;
+}
--
cgit v1.2.3
From 1add3cff84b7e2436d69b1e97ae689281e4a7c33 Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Mon, 10 Oct 2022 19:57:43 -0500
Subject: Refresh list of models/ckpts upon hitting restart gradio in the
settings pane
---
modules/ui.py | 4 ++++
1 file changed, 4 insertions(+)
(limited to 'modules/ui.py')
diff --git a/modules/ui.py b/modules/ui.py
index e8039d76..06ff118f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -39,6 +39,7 @@ import modules.generation_parameters_copypaste
from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
+from modules.sd_models import list_models
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -1290,6 +1291,9 @@ Requested path was: {f}
shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True
+ # refresh models so that new models/.ckpt's show up on reload
+ list_models()
+
restart_gradio.click(
fn=request_restart,
inputs=[],
--
cgit v1.2.3
From 8b7d3f1bef47bbe048f644ed0d8dd3ad46554045 Mon Sep 17 00:00:00 2001
From: Jairo Correa
Date: Tue, 11 Oct 2022 02:22:46 -0300
Subject: Make the ctrl+enter shortcut use the generate button on the current
tab
---
modules/ui.py | 2 +-
script.js | 11 +++++++++--
2 files changed, 10 insertions(+), 3 deletions(-)
(limited to 'modules/ui.py')
diff --git a/modules/ui.py b/modules/ui.py
index e8039d76..cafda884 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1331,7 +1331,7 @@ Requested path was: {f}
with gr.Tabs() as tabs:
for interface, label, ifid in interfaces:
- with gr.TabItem(label, id=ifid):
+ with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render()
if os.path.exists(os.path.join(script_path, "notification.mp3")):
diff --git a/script.js b/script.js
index a92c0f77..9543cbe6 100644
--- a/script.js
+++ b/script.js
@@ -6,6 +6,10 @@ function get_uiCurrentTab() {
return gradioApp().querySelector('.tabs button:not(.border-transparent)')
}
+function get_uiCurrentTabContent() {
+ return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])')
+}
+
uiUpdateCallbacks = []
uiTabChangeCallbacks = []
let uiCurrentTab = null
@@ -50,8 +54,11 @@ document.addEventListener("DOMContentLoaded", function() {
} else if (e.keyCode !== undefined) {
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true;
}
- if (handled) {
- gradioApp().querySelector("#txt2img_generate").click();
+ if (handled) {
+ button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
+ if (button) {
+ button.click();
+ }
e.preventDefault();
}
})
--
cgit v1.2.3
From 8617396c6df71074c7fd3d39419802026874712a Mon Sep 17 00:00:00 2001
From: Kenneth
Date: Mon, 10 Oct 2022 17:23:07 -0600
Subject: Added slider for deepbooru score threshold in settings
---
modules/shared.py | 1 +
modules/ui.py | 2 +-
2 files changed, 2 insertions(+), 1 deletion(-)
(limited to 'modules/ui.py')
diff --git a/modules/shared.py b/modules/shared.py
index ecd15ef5..e0830e28 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -239,6 +239,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
+ "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
}))
options_templates.update(options_section(('ui', "User interface"), {
diff --git a/modules/ui.py b/modules/ui.py
index cafda884..ca3151c4 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -311,7 +311,7 @@ def interrogate(image):
def interrogate_deepbooru(image):
- prompt = get_deepbooru_tags(image)
+ prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold)
return gr_show(True) if prompt is None else prompt
--
cgit v1.2.3
From 530103b586109c11fd068eb70ef09503ec6a4caf Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 11 Oct 2022 14:53:02 +0300
Subject: fixes related to merge
---
modules/hypernetwork.py | 103 -------------------------
modules/hypernetwork/hypernetwork.py | 74 +++++++++++-------
modules/hypernetwork/ui.py | 10 +--
modules/sd_hijack_optimizations.py | 3 +-
modules/shared.py | 13 +++-
modules/textual_inversion/textual_inversion.py | 12 +--
modules/ui.py | 5 +-
scripts/xy_grid.py | 3 +-
webui.py | 15 +---
9 files changed, 78 insertions(+), 160 deletions(-)
delete mode 100644 modules/hypernetwork.py
(limited to 'modules/ui.py')
diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
deleted file mode 100644
index 7bbc443e..00000000
--- a/modules/hypernetwork.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import glob
-import os
-import sys
-import traceback
-
-import torch
-
-from ldm.util import default
-from modules import devices, shared
-import torch
-from torch import einsum
-from einops import rearrange, repeat
-
-
-class HypernetworkModule(torch.nn.Module):
- def __init__(self, dim, state_dict):
- super().__init__()
-
- self.linear1 = torch.nn.Linear(dim, dim * 2)
- self.linear2 = torch.nn.Linear(dim * 2, dim)
-
- self.load_state_dict(state_dict, strict=True)
- self.to(devices.device)
-
- def forward(self, x):
- return x + (self.linear2(self.linear1(x)))
-
-
-class Hypernetwork:
- filename = None
- name = None
-
- def __init__(self, filename):
- self.filename = filename
- self.name = os.path.splitext(os.path.basename(filename))[0]
- self.layers = {}
-
- state_dict = torch.load(filename, map_location='cpu')
- for size, sd in state_dict.items():
- self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
-
-
-def list_hypernetworks(path):
- res = {}
- for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
- name = os.path.splitext(os.path.basename(filename))[0]
- res[name] = filename
- return res
-
-
-def load_hypernetwork(filename):
- path = shared.hypernetworks.get(filename, None)
- if path is not None:
- print(f"Loading hypernetwork {filename}")
- try:
- shared.loaded_hypernetwork = Hypernetwork(path)
- except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- if shared.loaded_hypernetwork is not None:
- print(f"Unloading hypernetwork")
-
- shared.loaded_hypernetwork = None
-
-
-def apply_hypernetwork(hypernetwork, context):
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is None:
- return context, context
-
- context_k = hypernetwork_layers[0](context)
- context_v = hypernetwork_layers[1](context)
- return context_k, context_v
-
-
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
- h = self.heads
-
- q = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context)
- k = self.to_k(context_k)
- v = self.to_v(context_v)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
- if mask is not None:
- mask = rearrange(mask, 'b ... -> b (...)')
- max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
- sim.masked_fill_(~mask, max_neg_value)
-
- # attention, what we cannot get enough of
- attn = sim.softmax(dim=-1)
-
- out = einsum('b i j, b j d -> b i d', attn, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
- return self.to_out(out)
diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py
index a3d6a47e..aa701bda 100644
--- a/modules/hypernetwork/hypernetwork.py
+++ b/modules/hypernetwork/hypernetwork.py
@@ -26,10 +26,11 @@ class HypernetworkModule(torch.nn.Module):
if state_dict is not None:
self.load_state_dict(state_dict, strict=True)
else:
- self.linear1.weight.data.fill_(0.0001)
- self.linear1.bias.data.fill_(0.0001)
- self.linear2.weight.data.fill_(0.0001)
- self.linear2.bias.data.fill_(0.0001)
+
+ self.linear1.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear1.bias.data.zero_()
+ self.linear2.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear2.bias.data.zero_()
self.to(devices.device)
@@ -92,41 +93,54 @@ class Hypernetwork:
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
-def load_hypernetworks(path):
+def list_hypernetworks(path):
res = {}
+ for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+ name = os.path.splitext(os.path.basename(filename))[0]
+ res[name] = filename
+ return res
- for filename in glob.iglob(path + '**/*.pt', recursive=True):
+
+def load_hypernetwork(filename):
+ path = shared.hypernetworks.get(filename, None)
+ if path is not None:
+ print(f"Loading hypernetwork {filename}")
try:
- hn = Hypernetwork()
- hn.load(filename)
- res[hn.name] = hn
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
+
except Exception:
- print(f"Error loading hypernetwork {filename}", file=sys.stderr)
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ else:
+ if shared.loaded_hypernetwork is not None:
+ print(f"Unloading hypernetwork")
- return res
+ shared.loaded_hypernetwork = None
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
- h = self.heads
+def apply_hypernetwork(hypernetwork, context, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
- q = self.to_q(x)
- context = default(context, x)
+ if hypernetwork_layers is None:
+ return context, context
- hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None)
+ if layer is not None:
+ layer.hyper_k = hypernetwork_layers[0]
+ layer.hyper_v = hypernetwork_layers[1]
- if hypernetwork_layers is not None:
- hypernetwork_k, hypernetwork_v = hypernetwork_layers
+ context_k = hypernetwork_layers[0](context)
+ context_v = hypernetwork_layers[1](context)
+ return context_k, context_v
- self.hypernetwork_k = hypernetwork_k
- self.hypernetwork_v = hypernetwork_v
- context_k = hypernetwork_k(context)
- context_v = hypernetwork_v(context)
- else:
- context_k = context
- context_v = context
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
k = self.to_k(context_k)
v = self.to_v(context_v)
@@ -151,7 +165,9 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
assert hypernetwork_name, 'embedding not selected'
- shared.hypernetwork = shared.hypernetworks[hypernetwork_name]
+ path = shared.hypernetworks.get(hypernetwork_name, None)
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
@@ -176,9 +192,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
- hypernetwork = shared.hypernetworks[hypernetwork_name]
+ hypernetwork = shared.loaded_hypernetwork
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
@@ -194,7 +210,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps:
return hypernetwork, filename
- pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
+ pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, (x, text) in pbar:
hypernetwork.step = i + ititial_step
diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py
index 525f978c..f6d1d0a3 100644
--- a/modules/hypernetwork/ui.py
+++ b/modules/hypernetwork/ui.py
@@ -6,24 +6,24 @@ import gradio as gr
import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
+from modules.hypernetwork import hypernetwork
def create_hypernetwork(name):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
- hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
- hypernetwork.save(fn)
+ hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
+ hypernet.save(fn)
shared.reload_hypernetworks()
- shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None)
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
def train_hypernetwork(*args):
- initial_hypernetwork = shared.hypernetwork
+ initial_hypernetwork = shared.loaded_hypernetwork
try:
sd_hijack.undo_optimizations()
@@ -38,6 +38,6 @@ Hypernetwork saved to {html.escape(filename)}
except Exception:
raise
finally:
- shared.hypernetwork = initial_hypernetwork
+ shared.loaded_hypernetwork = initial_hypernetwork
sd_hijack.apply_optimizations()
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 25cb67a4..27e571fc 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -8,7 +8,8 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared, hypernetwork
+from modules import shared
+from modules.hypernetwork import hypernetwork
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
diff --git a/modules/shared.py b/modules/shared.py
index 14b40d70..8753015e 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,7 +13,8 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers, hypernetwork
+from modules import sd_samplers
+from modules.hypernetwork import hypernetwork
from modules.paths import models_path, script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -29,6 +30,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
@@ -82,10 +84,17 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
xformers_available = False
config_filename = cmd_opts.ui_settings_file
-hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks'))
+hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
+def reload_hypernetworks():
+ global hypernetworks
+
+ hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+ hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+
+
class State:
skipped = False
interrupted = False
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5965c5a0..d6977950 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -238,12 +238,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
+ preview_text = text if preview_image_prompt == "" else preview_image_prompt
+
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
- prompt=text,
+ prompt=preview_text,
steps=20,
- height=training_height,
- width=training_width,
+ height=training_height,
+ width=training_width,
do_not_save_grid=True,
do_not_save_samples=True,
)
@@ -254,7 +256,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.current_image = image
image.save(last_saved_image)
- last_saved_image += f", prompt: {text}"
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
diff --git a/modules/ui.py b/modules/ui.py
index 10b1ee3a..df653059 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1023,7 +1023,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="")
with gr.Column():
- create_embedding = gr.Button(value="Create", variant='primary')
+ create_embedding = gr.Button(value="Create embedding", variant='primary')
with gr.Group():
gr.HTML(value="Create a new hypernetwork
")
@@ -1035,7 +1035,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="")
with gr.Column():
- create_hypernetwork = gr.Button(value="Create", variant='primary')
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
with gr.Group():
gr.HTML(value="Preprocess images
")
@@ -1147,6 +1147,7 @@ def create_ui(wrap_gradio_gpu_call):
create_image_every,
save_embedding_every,
template_file,
+ preview_image_prompt,
],
outputs=[
ti_output,
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 42e1489c..0af5993c 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -10,7 +10,8 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
-from modules import images, hypernetwork
+from modules import images
+from modules.hypernetwork import hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
diff --git a/webui.py b/webui.py
index 7c200551..ba2156c8 100644
--- a/webui.py
+++ b/webui.py
@@ -29,6 +29,7 @@ from modules import devices
from modules import modelloader
from modules.paths import script_path
from modules.shared import cmd_opts
+import modules.hypernetwork.hypernetwork
modelloader.cleanup_models()
modules.sd_models.setup_model()
@@ -77,22 +78,12 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
-def set_hypernetwork():
- shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None)
-
-
-shared.reload_hypernetworks()
-shared.opts.onchange("sd_hypernetwork", set_hypernetwork)
-set_hypernetwork()
-
-
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
shared.sd_model = modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
-loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)
-shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
+shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
def webui():
@@ -117,7 +108,7 @@ def webui():
prevent_thread_lock=True
)
- app.add_middleware(GZipMiddleware,minimum_size=1000)
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
while 1:
time.sleep(0.5)
--
cgit v1.2.3
From 87b77cad5f3017c952a7dfec0e7904a9df5b72fd Mon Sep 17 00:00:00 2001
From: Ben <110583491+TheLastBen@users.noreply.github.com>
Date: Mon, 10 Oct 2022 19:37:16 +0100
Subject: Layout fix
---
modules/ui.py | 28 ++++++++++++++--------------
1 file changed, 14 insertions(+), 14 deletions(-)
(limited to 'modules/ui.py')
diff --git a/modules/ui.py b/modules/ui.py
index df653059..de4cd7f2 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -550,15 +550,15 @@ def create_ui(wrap_gradio_gpu_call):
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
- with gr.Row():
- do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
+ with gr.Row():
+ do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ with gr.Group():
+ html_info = gr.HTML()
+ generation_info = gr.Textbox(visible=False)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -738,15 +738,15 @@ def create_ui(wrap_gradio_gpu_call):
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
- with gr.Row():
- do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
+ with gr.Row():
+ do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ with gr.Group():
+ html_info = gr.HTML()
+ generation_info = gr.Textbox(visible=False)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
--
cgit v1.2.3
From 861297cefe2bb663f4e09dd4778a4cb93ebe8ff1 Mon Sep 17 00:00:00 2001
From: Ben <110583491+TheLastBen@users.noreply.github.com>
Date: Tue, 11 Oct 2022 08:08:45 +0100
Subject: add a space holder
---
modules/ui.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
(limited to 'modules/ui.py')
diff --git a/modules/ui.py b/modules/ui.py
index de4cd7f2..fc0f3d3c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -429,7 +429,10 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=8):
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
+ with gr.Row():
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
+ with gr.Column(scale=1, elem_id="roll_col"):
+ sh = gr.Button(elem_id="sh", visible=True)
with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
--
cgit v1.2.3
From 59925644480b6fd84f6bb84b4df7d4fbc6a0cce8 Mon Sep 17 00:00:00 2001
From: JamnedZ
Date: Tue, 11 Oct 2022 16:40:27 +0700
Subject: Cleaned ngrok integration
---
modules/ngrok.py | 15 +++++++++++++++
modules/shared.py | 1 +
modules/ui.py | 5 +++++
3 files changed, 21 insertions(+)
create mode 100644 modules/ngrok.py
(limited to 'modules/ui.py')
diff --git a/modules/ngrok.py b/modules/ngrok.py
new file mode 100644
index 00000000..17e6976f
--- /dev/null
+++ b/modules/ngrok.py
@@ -0,0 +1,15 @@
+from pyngrok import ngrok, conf, exception
+
+
+def connect(token, port):
+ if token == None:
+ token = 'None'
+ conf.get_default().auth_token = token
+ try:
+ public_url = ngrok.connect(port).public_url
+ except exception.PyngrokNgrokError:
+ print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
+ f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
+ else:
+ print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
+ 'You can use this link after the launch is complete.')
\ No newline at end of file
diff --git a/modules/shared.py b/modules/shared.py
index 8753015e..375e3afb 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -38,6 +38,7 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
+parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
diff --git a/modules/ui.py b/modules/ui.py
index fc0f3d3c..f57f32db 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -51,6 +51,11 @@ if not cmd_opts.share and not cmd_opts.listen:
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
+if cmd_opts.ngrok != None:
+ import modules.ngrok as ngrok
+ print('ngrok authtoken detected, trying to connect...')
+ ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860)
+
def gr_show(visible=True):
return {"visible": visible, "__type__": "update"}
--
cgit v1.2.3
From 873efeed49bb5197a42da18272115b326c5d68f3 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 11 Oct 2022 15:51:22 +0300
Subject: rename hypernetwork dir to hypernetworks to prevent clash with an old
filename that people who use zip instead of git clone will have
---
modules/hypernetwork/hypernetwork.py | 283 ----------------------------------
modules/hypernetwork/ui.py | 43 ------
modules/hypernetworks/hypernetwork.py | 283 ++++++++++++++++++++++++++++++++++
modules/hypernetworks/ui.py | 43 ++++++
modules/sd_hijack.py | 2 +-
modules/sd_hijack_optimizations.py | 2 +-
modules/shared.py | 2 +-
modules/ui.py | 2 +-
scripts/xy_grid.py | 2 +-
webui.py | 2 +-
10 files changed, 332 insertions(+), 332 deletions(-)
delete mode 100644 modules/hypernetwork/hypernetwork.py
delete mode 100644 modules/hypernetwork/ui.py
create mode 100644 modules/hypernetworks/hypernetwork.py
create mode 100644 modules/hypernetworks/ui.py
(limited to 'modules/ui.py')
diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py
deleted file mode 100644
index aa701bda..00000000
--- a/modules/hypernetwork/hypernetwork.py
+++ /dev/null
@@ -1,283 +0,0 @@
-import datetime
-import glob
-import html
-import os
-import sys
-import traceback
-import tqdm
-
-import torch
-
-from ldm.util import default
-from modules import devices, shared, processing, sd_models
-import torch
-from torch import einsum
-from einops import rearrange, repeat
-import modules.textual_inversion.dataset
-
-
-class HypernetworkModule(torch.nn.Module):
- def __init__(self, dim, state_dict=None):
- super().__init__()
-
- self.linear1 = torch.nn.Linear(dim, dim * 2)
- self.linear2 = torch.nn.Linear(dim * 2, dim)
-
- if state_dict is not None:
- self.load_state_dict(state_dict, strict=True)
- else:
-
- self.linear1.weight.data.normal_(mean=0.0, std=0.01)
- self.linear1.bias.data.zero_()
- self.linear2.weight.data.normal_(mean=0.0, std=0.01)
- self.linear2.bias.data.zero_()
-
- self.to(devices.device)
-
- def forward(self, x):
- return x + (self.linear2(self.linear1(x)))
-
-
-class Hypernetwork:
- filename = None
- name = None
-
- def __init__(self, name=None):
- self.filename = None
- self.name = name
- self.layers = {}
- self.step = 0
- self.sd_checkpoint = None
- self.sd_checkpoint_name = None
-
- for size in [320, 640, 768, 1280]:
- self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
-
- def weights(self):
- res = []
-
- for k, layers in self.layers.items():
- for layer in layers:
- layer.train()
- res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
-
- return res
-
- def save(self, filename):
- state_dict = {}
-
- for k, v in self.layers.items():
- state_dict[k] = (v[0].state_dict(), v[1].state_dict())
-
- state_dict['step'] = self.step
- state_dict['name'] = self.name
- state_dict['sd_checkpoint'] = self.sd_checkpoint
- state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
-
- torch.save(state_dict, filename)
-
- def load(self, filename):
- self.filename = filename
- if self.name is None:
- self.name = os.path.splitext(os.path.basename(filename))[0]
-
- state_dict = torch.load(filename, map_location='cpu')
-
- for size, sd in state_dict.items():
- if type(size) == int:
- self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
-
- self.name = state_dict.get('name', self.name)
- self.step = state_dict.get('step', 0)
- self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
- self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
-
-
-def list_hypernetworks(path):
- res = {}
- for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
- name = os.path.splitext(os.path.basename(filename))[0]
- res[name] = filename
- return res
-
-
-def load_hypernetwork(filename):
- path = shared.hypernetworks.get(filename, None)
- if path is not None:
- print(f"Loading hypernetwork {filename}")
- try:
- shared.loaded_hypernetwork = Hypernetwork()
- shared.loaded_hypernetwork.load(path)
-
- except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- if shared.loaded_hypernetwork is not None:
- print(f"Unloading hypernetwork")
-
- shared.loaded_hypernetwork = None
-
-
-def apply_hypernetwork(hypernetwork, context, layer=None):
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is None:
- return context, context
-
- if layer is not None:
- layer.hyper_k = hypernetwork_layers[0]
- layer.hyper_v = hypernetwork_layers[1]
-
- context_k = hypernetwork_layers[0](context)
- context_v = hypernetwork_layers[1](context)
- return context_k, context_v
-
-
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
- h = self.heads
-
- q = self.to_q(x)
- context = default(context, x)
-
- context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
- k = self.to_k(context_k)
- v = self.to_v(context_v)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
- if mask is not None:
- mask = rearrange(mask, 'b ... -> b (...)')
- max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
- sim.masked_fill_(~mask, max_neg_value)
-
- # attention, what we cannot get enough of
- attn = sim.softmax(dim=-1)
-
- out = einsum('b i j, b j d -> b i d', attn, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
- return self.to_out(out)
-
-
-def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
- assert hypernetwork_name, 'embedding not selected'
-
- path = shared.hypernetworks.get(hypernetwork_name, None)
- shared.loaded_hypernetwork = Hypernetwork()
- shared.loaded_hypernetwork.load(path)
-
- shared.state.textinfo = "Initializing hypernetwork training..."
- shared.state.job_count = steps
-
- filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
-
- log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
-
- if save_hypernetwork_every > 0:
- hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
- os.makedirs(hypernetwork_dir, exist_ok=True)
- else:
- hypernetwork_dir = None
-
- if create_image_every > 0:
- images_dir = os.path.join(log_directory, "images")
- os.makedirs(images_dir, exist_ok=True)
- else:
- images_dir = None
-
- cond_model = shared.sd_model.cond_stage_model
-
- shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
- with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
-
- hypernetwork = shared.loaded_hypernetwork
- weights = hypernetwork.weights()
- for weight in weights:
- weight.requires_grad = True
-
- optimizer = torch.optim.AdamW(weights, lr=learn_rate)
-
- losses = torch.zeros((32,))
-
- last_saved_file = ""
- last_saved_image = ""
-
- ititial_step = hypernetwork.step or 0
- if ititial_step > steps:
- return hypernetwork, filename
-
- pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
- for i, (x, text) in pbar:
- hypernetwork.step = i + ititial_step
-
- if hypernetwork.step > steps:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = cond_model([text])
-
- x = x.to(devices.device)
- loss = shared.sd_model(x.unsqueeze(0), c)[0]
- del x
-
- losses[hypernetwork.step % losses.shape[0]] = loss.item()
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- pbar.set_description(f"loss: {losses.mean():.7f}")
-
- if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
- hypernetwork.save(last_saved_file)
-
- if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
- last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
-
- preview_text = text if preview_image_prompt == "" else preview_image_prompt
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- prompt=preview_text,
- steps=20,
- do_not_save_grid=True,
- do_not_save_samples=True,
- )
-
- processed = processing.process_images(p)
- image = processed.images[0]
-
- shared.state.current_image = image
- image.save(last_saved_image)
-
- last_saved_image += f", prompt: {preview_text}"
-
- shared.state.job_no = hypernetwork.step
-
- shared.state.textinfo = f"""
-
-Loss: {losses.mean():.7f}
-Step: {hypernetwork.step}
-Last prompt: {html.escape(text)}
-Last saved embedding: {html.escape(last_saved_file)}
-Last saved image: {html.escape(last_saved_image)}
-
-"""
-
- checkpoint = sd_models.select_checkpoint()
-
- hypernetwork.sd_checkpoint = checkpoint.hash
- hypernetwork.sd_checkpoint_name = checkpoint.model_name
- hypernetwork.save(filename)
-
- return hypernetwork, filename
-
-
diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py
deleted file mode 100644
index f6d1d0a3..00000000
--- a/modules/hypernetwork/ui.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import html
-import os
-
-import gradio as gr
-
-import modules.textual_inversion.textual_inversion
-import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared
-from modules.hypernetwork import hypernetwork
-
-
-def create_hypernetwork(name):
- fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
-
- hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
- hypernet.save(fn)
-
- shared.reload_hypernetworks()
-
- return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
-
-
-def train_hypernetwork(*args):
-
- initial_hypernetwork = shared.loaded_hypernetwork
-
- try:
- sd_hijack.undo_optimizations()
-
- hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args)
-
- res = f"""
-Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
-Hypernetwork saved to {html.escape(filename)}
-"""
- return res, ""
- except Exception:
- raise
- finally:
- shared.loaded_hypernetwork = initial_hypernetwork
- sd_hijack.apply_optimizations()
-
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
new file mode 100644
index 00000000..aa701bda
--- /dev/null
+++ b/modules/hypernetworks/hypernetwork.py
@@ -0,0 +1,283 @@
+import datetime
+import glob
+import html
+import os
+import sys
+import traceback
+import tqdm
+
+import torch
+
+from ldm.util import default
+from modules import devices, shared, processing, sd_models
+import torch
+from torch import einsum
+from einops import rearrange, repeat
+import modules.textual_inversion.dataset
+
+
+class HypernetworkModule(torch.nn.Module):
+ def __init__(self, dim, state_dict=None):
+ super().__init__()
+
+ self.linear1 = torch.nn.Linear(dim, dim * 2)
+ self.linear2 = torch.nn.Linear(dim * 2, dim)
+
+ if state_dict is not None:
+ self.load_state_dict(state_dict, strict=True)
+ else:
+
+ self.linear1.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear1.bias.data.zero_()
+ self.linear2.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear2.bias.data.zero_()
+
+ self.to(devices.device)
+
+ def forward(self, x):
+ return x + (self.linear2(self.linear1(x)))
+
+
+class Hypernetwork:
+ filename = None
+ name = None
+
+ def __init__(self, name=None):
+ self.filename = None
+ self.name = name
+ self.layers = {}
+ self.step = 0
+ self.sd_checkpoint = None
+ self.sd_checkpoint_name = None
+
+ for size in [320, 640, 768, 1280]:
+ self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
+
+ def weights(self):
+ res = []
+
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.train()
+ res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
+
+ return res
+
+ def save(self, filename):
+ state_dict = {}
+
+ for k, v in self.layers.items():
+ state_dict[k] = (v[0].state_dict(), v[1].state_dict())
+
+ state_dict['step'] = self.step
+ state_dict['name'] = self.name
+ state_dict['sd_checkpoint'] = self.sd_checkpoint
+ state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+
+ torch.save(state_dict, filename)
+
+ def load(self, filename):
+ self.filename = filename
+ if self.name is None:
+ self.name = os.path.splitext(os.path.basename(filename))[0]
+
+ state_dict = torch.load(filename, map_location='cpu')
+
+ for size, sd in state_dict.items():
+ if type(size) == int:
+ self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
+
+ self.name = state_dict.get('name', self.name)
+ self.step = state_dict.get('step', 0)
+ self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
+ self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
+
+
+def list_hypernetworks(path):
+ res = {}
+ for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+ name = os.path.splitext(os.path.basename(filename))[0]
+ res[name] = filename
+ return res
+
+
+def load_hypernetwork(filename):
+ path = shared.hypernetworks.get(filename, None)
+ if path is not None:
+ print(f"Loading hypernetwork {filename}")
+ try:
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
+
+ except Exception:
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ else:
+ if shared.loaded_hypernetwork is not None:
+ print(f"Unloading hypernetwork")
+
+ shared.loaded_hypernetwork = None
+
+
+def apply_hypernetwork(hypernetwork, context, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is None:
+ return context, context
+
+ if layer is not None:
+ layer.hyper_k = hypernetwork_layers[0]
+ layer.hyper_v = hypernetwork_layers[1]
+
+ context_k = hypernetwork_layers[0](context)
+ context_v = hypernetwork_layers[1](context)
+ return context_k, context_v
+
+
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if mask is not None:
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
+ assert hypernetwork_name, 'embedding not selected'
+
+ path = shared.hypernetworks.get(hypernetwork_name, None)
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
+
+ shared.state.textinfo = "Initializing hypernetwork training..."
+ shared.state.job_count = steps
+
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+
+ if save_hypernetwork_every > 0:
+ hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
+ os.makedirs(hypernetwork_dir, exist_ok=True)
+ else:
+ hypernetwork_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ cond_model = shared.sd_model.cond_stage_model
+
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ with torch.autocast("cuda"):
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+
+ hypernetwork = shared.loaded_hypernetwork
+ weights = hypernetwork.weights()
+ for weight in weights:
+ weight.requires_grad = True
+
+ optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+
+ losses = torch.zeros((32,))
+
+ last_saved_file = ""
+ last_saved_image = ""
+
+ ititial_step = hypernetwork.step or 0
+ if ititial_step > steps:
+ return hypernetwork, filename
+
+ pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
+ for i, (x, text) in pbar:
+ hypernetwork.step = i + ititial_step
+
+ if hypernetwork.step > steps:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = cond_model([text])
+
+ x = x.to(devices.device)
+ loss = shared.sd_model(x.unsqueeze(0), c)[0]
+ del x
+
+ losses[hypernetwork.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ pbar.set_description(f"loss: {losses.mean():.7f}")
+
+ if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
+ hypernetwork.save(last_saved_file)
+
+ if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
+ last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+
+ preview_text = text if preview_image_prompt == "" else preview_image_prompt
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ prompt=preview_text,
+ steps=20,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ processed = processing.process_images(p)
+ image = processed.images[0]
+
+ shared.state.current_image = image
+ image.save(last_saved_image)
+
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = hypernetwork.step
+
+ shared.state.textinfo = f"""
+
+Loss: {losses.mean():.7f}
+Step: {hypernetwork.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+
+"""
+
+ checkpoint = sd_models.select_checkpoint()
+
+ hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ hypernetwork.save(filename)
+
+ return hypernetwork, filename
+
+
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
new file mode 100644
index 00000000..811bc31e
--- /dev/null
+++ b/modules/hypernetworks/ui.py
@@ -0,0 +1,43 @@
+import html
+import os
+
+import gradio as gr
+
+import modules.textual_inversion.textual_inversion
+import modules.textual_inversion.preprocess
+from modules import sd_hijack, shared
+from modules.hypernetworks import hypernetwork
+
+
+def create_hypernetwork(name):
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
+ hypernet.save(fn)
+
+ shared.reload_hypernetworks()
+
+ return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
+
+
+def train_hypernetwork(*args):
+
+ initial_hypernetwork = shared.loaded_hypernetwork
+
+ try:
+ sd_hijack.undo_optimizations()
+
+ hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args)
+
+ res = f"""
+Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
+Hypernetwork saved to {html.escape(filename)}
+"""
+ return res, ""
+ except Exception:
+ raise
+ finally:
+ shared.loaded_hypernetwork = initial_hypernetwork
+ sd_hijack.apply_optimizations()
+
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index f873049a..f07ec041 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -37,7 +37,7 @@ def apply_optimizations():
def undo_optimizations():
- from modules.hypernetwork import hypernetwork
+ from modules.hypernetworks import hypernetwork
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 27e571fc..3349b9c3 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,7 +9,7 @@ from ldm.util import default
from einops import rearrange
from modules import shared
-from modules.hypernetwork import hypernetwork
+from modules.hypernetworks import hypernetwork
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
diff --git a/modules/shared.py b/modules/shared.py
index 375e3afb..1dc2ccf2 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -14,7 +14,7 @@ import modules.sd_models
import modules.styles
import modules.devices as devices
from modules import sd_samplers
-from modules.hypernetwork import hypernetwork
+from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
diff --git a/modules/ui.py b/modules/ui.py
index f57f32db..42e5d866 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -39,7 +39,7 @@ import modules.generation_parameters_copypaste
from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
-import modules.hypernetwork.ui
+import modules.hypernetworks.ui
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 16918c99..cddb192a 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -11,7 +11,7 @@ import modules.scripts as scripts
import gradio as gr
from modules import images
-from modules.hypernetwork import hypernetwork
+from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
diff --git a/webui.py b/webui.py
index ba2156c8..faa38a0d 100644
--- a/webui.py
+++ b/webui.py
@@ -29,7 +29,7 @@ from modules import devices
from modules import modelloader
from modules.paths import script_path
from modules.shared import cmd_opts
-import modules.hypernetwork.hypernetwork
+import modules.hypernetworks.hypernetwork
modelloader.cleanup_models()
modules.sd_models.setup_model()
--
cgit v1.2.3
From b0583be0884cd17dafb408fd79b52b2a0a972563 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 11 Oct 2022 15:54:34 +0300
Subject: more renames
---
modules/hypernetworks/ui.py | 4 ++--
modules/ui.py | 4 ++--
webui.py | 2 +-
3 files changed, 5 insertions(+), 5 deletions(-)
(limited to 'modules/ui.py')
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 811bc31e..e7540f41 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -13,7 +13,7 @@ def create_hypernetwork(name):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
- hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name)
hypernet.save(fn)
shared.reload_hypernetworks()
@@ -28,7 +28,7 @@ def train_hypernetwork(*args):
try:
sd_hijack.undo_optimizations()
- hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args)
+ hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
diff --git a/modules/ui.py b/modules/ui.py
index 42e5d866..ee333c3b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1111,7 +1111,7 @@ def create_ui(wrap_gradio_gpu_call):
)
create_hypernetwork.click(
- fn=modules.hypernetwork.ui.create_hypernetwork,
+ fn=modules.hypernetworks.ui.create_hypernetwork,
inputs=[
new_hypernetwork_name,
],
@@ -1164,7 +1164,7 @@ def create_ui(wrap_gradio_gpu_call):
)
train_hypernetwork.click(
- fn=wrap_gradio_gpu_call(modules.hypernetwork.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+ fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
diff --git a/webui.py b/webui.py
index faa38a0d..338f58e1 100644
--- a/webui.py
+++ b/webui.py
@@ -83,7 +83,7 @@ modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
shared.sd_model = modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
-shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
+shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
def webui():
--
cgit v1.2.3
From d01a2d01560b31937df1f3433d210c18f97d32fa Mon Sep 17 00:00:00 2001
From: papuSpartan
Date: Tue, 11 Oct 2022 08:03:31 -0500
Subject: move list refresh to webui.py and add stdout indicating it's doing so
---
modules/ui.py | 3 ---
webui.py | 2 ++
2 files changed, 2 insertions(+), 3 deletions(-)
(limited to 'modules/ui.py')
diff --git a/modules/ui.py b/modules/ui.py
index 06ff118f..ae9317a3 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -39,7 +39,6 @@ import modules.generation_parameters_copypaste
from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
-from modules.sd_models import list_models
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -1291,8 +1290,6 @@ Requested path was: {f}
shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True
- # refresh models so that new models/.ckpt's show up on reload
- list_models()
restart_gradio.click(
fn=request_restart,
diff --git a/webui.py b/webui.py
index 270584f7..94098c4c 100644
--- a/webui.py
+++ b/webui.py
@@ -124,6 +124,8 @@ def webui():
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
print('Reloading modules: modules.ui')
importlib.reload(modules.ui)
+ print('Refreshing Model List')
+ modules.sd_models.list_models()
print('Restarting Gradio')
--
cgit v1.2.3
From d682444ecc99319fbd2b142a12727501e2884ba7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 11 Oct 2022 18:04:47 +0300
Subject: add option to select hypernetwork modules when creating
---
modules/hypernetworks/hypernetwork.py | 4 ++--
modules/hypernetworks/ui.py | 4 ++--
modules/ui.py | 2 ++
3 files changed, 6 insertions(+), 4 deletions(-)
(limited to 'modules/ui.py')
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index aa701bda..b081f14e 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -42,7 +42,7 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None):
+ def __init__(self, name=None, enable_sizes=None):
self.filename = None
self.name = name
self.layers = {}
@@ -50,7 +50,7 @@ class Hypernetwork:
self.sd_checkpoint = None
self.sd_checkpoint_name = None
- for size in [320, 640, 768, 1280]:
+ for size in enable_sizes or [320, 640, 768, 1280]:
self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
def weights(self):
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index e7540f41..cdddcce1 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -9,11 +9,11 @@ from modules import sd_hijack, shared
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name):
+def create_hypernetwork(name, enable_sizes):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name)
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
hypernet.save(fn)
shared.reload_hypernetworks()
diff --git a/modules/ui.py b/modules/ui.py
index f2d16b12..14b87b92 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1037,6 +1037,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="Create a new hypernetwork
")
new_hypernetwork_name = gr.Textbox(label="Name")
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
with gr.Row():
with gr.Column(scale=3):
@@ -1114,6 +1115,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=modules.hypernetworks.ui.create_hypernetwork,
inputs=[
new_hypernetwork_name,
+ new_hypernetwork_sizes,
],
outputs=[
train_hypernetwork_name,
--
cgit v1.2.3